diff --git a/server/Makefile b/server/Makefile index ae89157e06..b21aae081e 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,4 @@ -.PHONY: build package run stop run-client run-server run-haserver stop-haserver stop-client stop-server restart restart-server restart-client restart-haserver start-docker update-docker clean-dist clean nuke check-style check-client-style check-server-style check-unit-tests test dist run-client-tests setup-run-client-tests cleanup-run-client-tests test-client build-linux build-osx build-windows package-prep package-linux package-osx package-windows internal-test-web-client vet run-server-for-web-client-tests diff-config prepackaged-plugins prepackaged-binaries test-server test-server-ee test-server-quick test-server-race test-mmctl-unit test-mmctl-e2e test-mmctl test-mmctl-coverage mmctl-build mmctl-docs new-migration migrations-extract +.PHONY: build package run stop run-client run-server run-haserver stop-haserver stop-client stop-server restart restart-server restart-client restart-haserver start-docker update-docker clean-dist clean nuke check-style check-client-style check-server-style check-unit-tests test dist run-client-tests setup-run-client-tests cleanup-run-client-tests test-client build-linux build-osx build-windows package-prep package-linux package-osx package-windows internal-test-web-client vet run-server-for-web-client-tests diff-config prepackaged-plugins prepackaged-binaries test-server test-server-ee test-server-quick test-server-race test-mmctl-unit test-mmctl-e2e test-mmctl test-mmctl-coverage mmctl-build mmctl-docs new-migration migrations-extract test-public mocks-public ROOT := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) @@ -210,6 +210,7 @@ endif include config.mk include build/*.mk +include public/Makefile LDFLAGS += -X "github.com/mattermost/mattermost/server/public/model.MockCWS=$(MM_ENABLE_CWS_MOCK)" @@ -405,7 +406,7 @@ mmctl-mocks: ## Creates mocks for mmctl pluginapi: ## Generates api and hooks glue code for plugins cd ./public && $(GO) generate $(GOFLAGS) ./plugin -mocks: store-mocks telemetry-mocks filestore-mocks ldap-mocks plugin-mocks einterfaces-mocks searchengine-mocks sharedchannel-mocks misc-mocks email-mocks platform-mocks mmctl-mocks +mocks: store-mocks telemetry-mocks filestore-mocks ldap-mocks plugin-mocks einterfaces-mocks searchengine-mocks sharedchannel-mocks misc-mocks email-mocks platform-mocks mmctl-mocks mocks-public layers: app-layers store-layers pluginapi diff --git a/server/channels/app/plugin_api_tests/test_db_driver/main.go b/server/channels/app/plugin_api_tests/test_db_driver/main.go index 731fb0fc74..bbf50f4267 100644 --- a/server/channels/app/plugin_api_tests/test_db_driver/main.go +++ b/server/channels/app/plugin_api_tests/test_db_driver/main.go @@ -9,10 +9,10 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/shared/driver" "github.com/mattermost/mattermost/server/v8/channels/app/plugin_api_tests" "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" "github.com/mattermost/mattermost/server/v8/channels/store/storetest" - "github.com/mattermost/mattermost/server/v8/platform/shared/driver" ) type MyPlugin struct { diff --git a/server/public/Makefile b/server/public/Makefile new file mode 100644 index 0000000000..85c96e86fc --- /dev/null +++ b/server/public/Makefile @@ -0,0 +1,14 @@ +test-public: gotestsum + $(GOBIN)/gotestsum ./public/... -- $(GOFLAGS) + +## Generates mock golang interfaces for testing +mocks-public: + $(GO) install github.com/golang/mock/mockgen@v1.6.0 + $(GOBIN)/mockgen -destination public/pluginapi/experimental/panel/mocks/mock_panel.go -package mock_panel github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel Panel + $(GOBIN)/mockgen -destination public/pluginapi/experimental/panel/mocks/mock_panelStore.go -package mock_panel github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel Store + $(GOBIN)/mockgen -destination public/pluginapi/experimental/panel/mocks/mock_setting.go -package mock_panel github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel/settings Setting + $(GOBIN)/mockgen -destination public/pluginapi/experimental/bot/mocks/mock_bot.go -package mock_bot github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot Bot + $(GOBIN)/mockgen -destination public/pluginapi/experimental/bot/mocks/mock_logger.go -package mock_bot github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger Logger + $(GOBIN)/mockgen -destination public/pluginapi/experimental/bot/mocks/mock_poster.go -package mock_bot github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster Poster + $(GOBIN)/mockgen -destination public/pluginapi/experimental/oauther/mocks/mock_oauther.go -package mock_oauther github.com/mattermost/mattermost/server/public/pluginapi/experimental/oauther OAuther + $(GOBIN)/mockgen -destination public/pluginapi/experimental/bot/poster/mock_import/mock_postapi.go -package mock_import github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster PostAPI diff --git a/server/public/go.mod b/server/public/go.mod index dcf3bbca85..c5911168bc 100644 --- a/server/public/go.mod +++ b/server/public/go.mod @@ -4,9 +4,12 @@ go 1.19 require ( github.com/blang/semver v3.5.1+incompatible + github.com/blang/semver/v4 v4.0.0 github.com/dyatlov/go-opengraph/opengraph v0.0.0-20220524092352-606d7b1e5f8a github.com/francoispqt/gojay v1.2.13 github.com/go-sql-driver/mysql v1.7.1 + github.com/golang/mock v1.6.0 + github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/graph-gophers/graphql-go v1.5.1-0.20230110080634-edea822f558a github.com/hashicorp/go-hclog v1.5.0 @@ -15,18 +18,23 @@ require ( github.com/mattermost/go-i18n v1.11.1-0.20211013152124-5c415071e404 github.com/mattermost/ldap v0.0.0-20201202150706-ee0e6284187d github.com/mattermost/logr/v2 v2.0.16 + github.com/nicksnyder/go-i18n/v2 v2.0.3 github.com/pborman/uuid v1.2.1 github.com/pkg/errors v0.9.1 + github.com/rudderlabs/analytics-go v3.3.3+incompatible + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 github.com/tinylib/msgp v1.1.8 github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/crypto v0.10.0 + golang.org/x/oauth2 v0.7.0 golang.org/x/text v0.10.0 golang.org/x/tools v0.10.0 gopkg.in/yaml.v2 v2.4.0 ) require ( + github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.15.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect @@ -41,13 +49,20 @@ require ( github.com/pelletier/go-toml v1.9.5 // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.8.0 // indirect + github.com/segmentio/backo-go v1.0.1 // indirect github.com/stretchr/objx v0.5.0 // indirect + github.com/tidwall/gjson v1.14.3 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wiggin77/merror v1.0.5 // indirect github.com/wiggin77/srslog v1.0.1 // indirect + github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c // indirect golang.org/x/mod v0.11.0 // indirect golang.org/x/net v0.11.0 // indirect golang.org/x/sys v0.9.0 // indirect + google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230629202037-9506855d4529 // indirect google.golang.org/grpc v1.56.1 // indirect google.golang.org/protobuf v1.31.0 // indirect diff --git a/server/public/go.sum b/server/public/go.sum index 9442becd87..f1076d7bac 100644 --- a/server/public/go.sum +++ b/server/public/go.sum @@ -7,11 +7,17 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/BurntSushi/toml v0.3.0/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -46,6 +52,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -66,6 +74,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/graph-gophers/graphql-go v1.5.1-0.20230110080634-edea822f558a h1:i0+Se9S+2zL5CBxJouqn2Ej6UQMwH1c57ZB6DVnqck4= @@ -119,6 +129,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nicksnyder/go-i18n/v2 v2.0.3 h1:ks/JkQiOEhhuF6jpNvx+Wih1NIiXzUnZeZVnJuI8R8M= +github.com/nicksnyder/go-i18n/v2 v2.0.3/go.mod h1:oDab7q8XCYMRlcrBnaY/7B1eOectbvj6B1UPBT+p5jo= github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= @@ -130,6 +142,7 @@ github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3v github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -139,9 +152,14 @@ github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/rudderlabs/analytics-go v3.3.3+incompatible h1:OG0XlKoXfr539e2t1dXtTB+Gr89uFW+OUNQBVhHIIBY= +github.com/rudderlabs/analytics-go v3.3.3+incompatible/go.mod h1:LF8/ty9kUX4PTY3l5c97K3nZZaX5Hwsvt+NBaRL/f30= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/segmentio/backo-go v1.0.1 h1:68RQccglxZeyURy93ASB/2kc9QudzgIDexJ927N++y4= +github.com/segmentio/backo-go v1.0.1/go.mod h1:9/Rh6yILuLysoQnZ2oNooD2g7aBnvM7r/fNVxRNWfBc= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= @@ -165,6 +183,8 @@ github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1l github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -174,12 +194,19 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= @@ -192,6 +219,9 @@ github.com/wiggin77/merror v1.0.5 h1:P+lzicsn4vPMycAf2mFf7Zk6G9eco5N+jB1qJ2XW3ME github.com/wiggin77/merror v1.0.5/go.mod h1:H2ETSu7/bPE0Ymf4bEwdUoo73OOEkdClnoRisfw0Nm0= github.com/wiggin77/srslog v1.0.1 h1:gA2XjSMy3DrRdX9UqLuDtuVAAshb8bE1NhX1YK0Qe+8= github.com/wiggin77/srslog v1.0.1/go.mod h1:fehkyYDq1QfuYn60TDPu9YdY2bB85VUW2mvN1WynEls= +github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c h1:3lbZUMbMiGUW/LMkfsEABsc5zNT9+b1CvsJx47JzJ8g= +github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c/go.mod h1:UrdRz5enIKZ63MEE3IF9l2/ebyx59GyGgPi+tICQdmM= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI= @@ -201,6 +231,8 @@ golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+ golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= @@ -208,6 +240,7 @@ golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= @@ -219,9 +252,14 @@ golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220520000938-2e3eb7b945c2/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= @@ -231,12 +269,15 @@ golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAG golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g= +golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= @@ -245,15 +286,20 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -265,6 +311,7 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= @@ -277,13 +324,17 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= golang.org/x/tools v0.10.0 h1:tvDr/iQoUqNdohiYm0LmmKcBk+q86lb9EprIUFhHHGg= golang.org/x/tools v0.10.0/go.mod h1:UJwyiVBsOA2uwvK/e5OY3GTpDUJriEd+/YlqAwLPmyM= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= @@ -291,6 +342,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= diff --git a/server/public/model/manifest.go b/server/public/model/manifest.go index 66eb218e3b..4b865fc159 100644 --- a/server/public/model/manifest.go +++ b/server/public/model/manifest.go @@ -191,7 +191,8 @@ type Manifest struct { // RequiredConfig defines any required server configuration fields for the plugin to function properly. // - // Use the pluginapi.Configuration.CheckRequiredServerConfiguration method to enforce this. + // Deprecated: The required server configuration fields should be checked using custom code. + // This field will get removed in the next major release. RequiredConfig *Config `json:"required_configuration,omitempty" yaml:"required_configuration,omitempty"` } diff --git a/server/public/pluginapi/bot.go b/server/public/pluginapi/bot.go new file mode 100644 index 0000000000..74a7710b74 --- /dev/null +++ b/server/public/pluginapi/bot.go @@ -0,0 +1,202 @@ +package pluginapi + +import ( + "os" + "path/filepath" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/pluginapi/cluster" +) + +const ( + internalKeyPrefix = "mmi_" + botUserKey = internalKeyPrefix + "botid" + botEnsureMutexKey = internalKeyPrefix + "bot_ensure" +) + +// BotService exposes methods to manipulate bots. +type BotService struct { + api plugin.API +} + +// Get returns a bot by botUserID. +// +// Minimum server version: 5.10 +func (b *BotService) Get(botUserID string, includeDeleted bool) (*model.Bot, error) { + bot, appErr := b.api.GetBot(botUserID, includeDeleted) + + return bot, normalizeAppErr(appErr) +} + +// BotListOption is an option to configure a bot List() request. +type BotListOption func(*model.BotGetOptions) + +// BotOwner option configures bot list request to only retrieve the bots that matches with +// owner's id. +func BotOwner(id string) BotListOption { + return func(o *model.BotGetOptions) { + o.OwnerId = id + } +} + +// BotIncludeDeleted option configures bot list request to also retrieve the deleted bots. +func BotIncludeDeleted() BotListOption { + return func(o *model.BotGetOptions) { + o.IncludeDeleted = true + } +} + +// BotOnlyOrphans option configures bot list request to only retrieve orphan bots. +func BotOnlyOrphans() BotListOption { + return func(o *model.BotGetOptions) { + o.OnlyOrphaned = true + } +} + +// List returns a list of bots by page, count and options. +// +// Minimum server version: 5.10 +func (b *BotService) List(page, perPage int, options ...BotListOption) ([]*model.Bot, error) { + opts := &model.BotGetOptions{ + Page: page, + PerPage: perPage, + } + for _, o := range options { + o(opts) + } + bots, appErr := b.api.GetBots(opts) + + return bots, normalizeAppErr(appErr) +} + +// Create creates the bot and corresponding user. +// +// Minimum server version: 5.10 +func (b *BotService) Create(bot *model.Bot) error { + createdBot, appErr := b.api.CreateBot(bot) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *bot = *createdBot + + return nil +} + +// Patch applies the given patch to the bot and corresponding user. +// +// Minimum server version: 5.10 +func (b *BotService) Patch(botUserID string, botPatch *model.BotPatch) (*model.Bot, error) { + bot, appErr := b.api.PatchBot(botUserID, botPatch) + + return bot, normalizeAppErr(appErr) +} + +// UpdateActive marks a bot as active or inactive, along with its corresponding user. +// +// Minimum server version: 5.10 +func (b *BotService) UpdateActive(botUserID string, isActive bool) (*model.Bot, error) { + bot, appErr := b.api.UpdateBotActive(botUserID, isActive) + + return bot, normalizeAppErr(appErr) +} + +// DeletePermanently permanently deletes a bot and its corresponding user. +// +// Minimum server version: 5.10 +func (b *BotService) DeletePermanently(botUserID string) error { + return normalizeAppErr(b.api.PermanentDeleteBot(botUserID)) +} + +type ensureBotOptions struct { + ProfileImagePath string +} + +type EnsureBotOption func(*ensureBotOptions) + +func ProfileImagePath(path string) EnsureBotOption { + return func(args *ensureBotOptions) { + args.ProfileImagePath = path + } +} + +// EnsureBot either returns an existing bot user matching the given bot, or creates a bot user from the given bot. +// A profile image or icon image may be optionally passed in to be set for the existing or newly created bot. +// Returns the id of the resulting bot. +// EnsureBot can safely be called multiple instances of a plugin concurrently. +// +// Minimum server version: 5.10 +func (b *BotService) EnsureBot(bot *model.Bot, options ...EnsureBotOption) (string, error) { + m, err := cluster.NewMutex(b.api, botEnsureMutexKey) + if err != nil { + return "", errors.Wrap(err, "failed to create mutex") + } + + return b.ensureBot(m, bot, options...) +} + +type mutex interface { + Lock() + Unlock() +} + +// TODO: this utility function is also used by the product framework. We should move this to mattermost-server and share +// the code to maintain consistent behavior. Ticket: MM-44953 +func (b *BotService) ensureBot(m mutex, bot *model.Bot, options ...EnsureBotOption) (string, error) { + err := ensureServerVersion(b.api, "5.10.0") + if err != nil { + return "", errors.Wrap(err, "failed to ensure bot") + } + + // Default options + o := &ensureBotOptions{ + ProfileImagePath: "", + } + + for _, setter := range options { + setter(o) + } + + botID, err := b.ensureBotUser(m, bot) + if err != nil { + return "", err + } + + if o.ProfileImagePath != "" { + imageBytes, err := b.readFile(o.ProfileImagePath) + if err != nil { + return "", errors.Wrap(err, "failed to read profile image") + } + appErr := b.api.SetProfileImage(botID, imageBytes) + if appErr != nil { + return "", errors.Wrap(appErr, "failed to set profile image") + } + } + + return botID, nil +} + +func (b *BotService) ensureBotUser(m mutex, bot *model.Bot) (retBotID string, retErr error) { + // Lock to prevent two plugins from racing to create the bot account + m.Lock() + defer m.Unlock() + + return b.api.EnsureBotUser(bot) +} + +func (b *BotService) readFile(path string) ([]byte, error) { + bundlePath, err := b.api.GetBundlePath() + if err != nil { + return nil, errors.Wrap(err, "failed to get bundle path") + } + + imageBytes, err := os.ReadFile(filepath.Join(bundlePath, path)) + if err != nil { + return nil, errors.Wrap(err, "failed to read image") + } + + return imageBytes, nil +} diff --git a/server/public/pluginapi/bot_test.go b/server/public/pluginapi/bot_test.go new file mode 100644 index 0000000000..cbcf321ebc --- /dev/null +++ b/server/public/pluginapi/bot_test.go @@ -0,0 +1,368 @@ +package pluginapi + +import ( + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" +) + +func TestCreateBot(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + api.On("CreateBot", &model.Bot{Username: "1"}).Return(&model.Bot{Username: "1", UserId: "2"}, nil) + + bot := &model.Bot{Username: "1"} + err := client.Bot.Create(bot) + require.NoError(t, err) + require.Equal(t, &model.Bot{Username: "1", UserId: "2"}, bot) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("CreateBot", &model.Bot{Username: "1"}).Return(nil, appErr) + + bot := &model.Bot{Username: "1"} + err := client.Bot.Create(&model.Bot{Username: "1"}) + require.Equal(t, appErr, err) + require.Equal(t, &model.Bot{Username: "1"}, bot) + }) +} + +func TestUpdateBotStatus(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + api.On("UpdateBotActive", "1", true).Return(&model.Bot{UserId: "2"}, nil) + + bot, err := client.Bot.UpdateActive("1", true) + require.NoError(t, err) + require.Equal(t, &model.Bot{UserId: "2"}, bot) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("UpdateBotActive", "1", true).Return(nil, appErr) + + bot, err := client.Bot.UpdateActive("1", true) + require.Equal(t, appErr, err) + require.Zero(t, bot) + }) +} + +func TestGetBot(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + api.On("GetBot", "1", true).Return(&model.Bot{UserId: "2"}, nil) + + bot, err := client.Bot.Get("1", true) + require.NoError(t, err) + require.Equal(t, &model.Bot{UserId: "2"}, bot) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetBot", "1", true).Return(nil, appErr) + + bot, err := client.Bot.Get("1", true) + require.Equal(t, appErr, err) + require.Zero(t, bot) + }) +} + +func TestListBot(t *testing.T) { + tests := []struct { + name string + page, count int + options []BotListOption + expectedOptions *model.BotGetOptions + bots []*model.Bot + err error + }{ + { + "owner filter", + 1, + 2, + []BotListOption{ + BotOwner("3"), + }, + &model.BotGetOptions{ + Page: 1, + PerPage: 2, + OwnerId: "3", + }, + []*model.Bot{ + {UserId: "4"}, + {UserId: "5"}, + }, + nil, + }, + { + "all filter", + 1, + 2, + []BotListOption{ + BotOwner("3"), + BotIncludeDeleted(), + BotOnlyOrphans(), + }, + &model.BotGetOptions{ + Page: 1, + PerPage: 2, + OwnerId: "3", + IncludeDeleted: true, + OnlyOrphaned: true, + }, + []*model.Bot{ + {UserId: "4"}, + }, + nil, + }, + { + "no filter", + 1, + 2, + []BotListOption{}, + &model.BotGetOptions{ + Page: 1, + PerPage: 2, + }, + []*model.Bot{ + {UserId: "4"}, + }, + nil, + }, + { + "app error", + 1, + 2, + []BotListOption{ + BotOwner("3"), + }, + &model.BotGetOptions{ + Page: 1, + PerPage: 2, + OwnerId: "3", + }, + nil, + newAppError(), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + api := &plugintest.API{} + client := NewClient(api, &plugintest.Driver{}) + + api.On("GetBots", test.expectedOptions).Return(test.bots, test.err) + + bots, err := client.Bot.List(test.page, test.count, test.options...) + if test.err != nil { + require.Equal(t, test.err.Error(), err.Error(), test.name) + } else { + require.NoError(t, err, test.name) + } + require.Equal(t, test.bots, bots, test.name) + + api.AssertExpectations(t) + }) + } +} + +func TestDeleteBotPermanently(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + api.On("PermanentDeleteBot", "1").Return(nil) + + err := client.Bot.DeletePermanently("1") + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("PermanentDeleteBot", "1").Return(appErr) + + err := client.Bot.DeletePermanently("1") + require.Equal(t, appErr, err) + }) +} + +func TestEnsureBot(t *testing.T) { + testbot := &model.Bot{ + Username: "testbot", + DisplayName: "Test Bot", + Description: "testbotdescription", + } + + m := testMutex{} + + t.Run("server version incompatible", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + api.On("GetServerVersion").Return("5.9.0") + + _, err := client.Bot.ensureBot(m, nil) + require.Error(t, err) + assert.Equal(t, + "failed to ensure bot: incompatible server version for plugin, minimum required version: 5.10.0, current version: 5.9.0", + err.Error(), + ) + }) + + t.Run("if bot already exists", func(t *testing.T) { + t.Run("should find and return the existing bot ID", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + expectedBotID := model.NewId() + + api.On("GetServerVersion").Return("5.10.0") + api.On("EnsureBotUser", testbot).Return(expectedBotID, nil) + botID, err := client.Bot.ensureBot(m, testbot) + + require.NoError(t, err) + assert.Equal(t, expectedBotID, botID) + }) + + t.Run("should set the bot profile image when specified", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + expectedBotID := model.NewId() + + profileImageFile, err := os.CreateTemp("", "profile_image") + require.NoError(t, err) + + profileImageBytes := []byte("profile image") + err = os.WriteFile(profileImageFile.Name(), profileImageBytes, 0o600) + require.NoError(t, err) + + api.On("GetBundlePath").Return("", nil) + api.On("EnsureBotUser", testbot).Return(expectedBotID, nil) + api.On("SetProfileImage", expectedBotID, profileImageBytes).Return(nil) + api.On("GetServerVersion").Return("5.10.0") + + botID, err := client.Bot.ensureBot(m, testbot, ProfileImagePath(profileImageFile.Name())) + require.NoError(t, err) + assert.Equal(t, expectedBotID, botID) + }) + + t.Run("should find and update the bot with new bot details", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + expectedBotID := model.NewId() + expectedBotUsername := "updated_testbot" + expectedBotDisplayName := "Updated Test Bot" + expectedBotDescription := "updated testbotdescription" + + profileImageFile, err := os.CreateTemp("", "profile_image") + require.NoError(t, err) + + profileImageBytes := []byte("profile image") + err = os.WriteFile(profileImageFile.Name(), profileImageBytes, 0o600) + require.NoError(t, err) + + iconImageFile, err := os.CreateTemp("", "profile_image") + require.NoError(t, err) + + iconImageBytes := []byte("icon image") + err = os.WriteFile(iconImageFile.Name(), iconImageBytes, 0o600) + require.NoError(t, err) + + updatedTestBot := &model.Bot{ + Username: expectedBotUsername, + DisplayName: expectedBotDisplayName, + Description: expectedBotDescription, + } + api.On("GetServerVersion").Return("5.10.0") + api.On("EnsureBotUser", updatedTestBot).Return(expectedBotID, nil) + api.On("GetBundlePath").Return("", nil) + api.On("SetProfileImage", expectedBotID, profileImageBytes).Return(nil) + + botID, err := client.Bot.ensureBot(m, + updatedTestBot, + ProfileImagePath(profileImageFile.Name()), + ) + require.NoError(t, err) + assert.Equal(t, expectedBotID, botID) + }) + }) + + t.Run("if bot doesn't exist", func(t *testing.T) { + t.Run("should create bot and set the bot profile image when specified", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := NewClient(api, &plugintest.Driver{}) + + expectedBotID := model.NewId() + + profileImageFile, err := os.CreateTemp("", "profile_image") + require.NoError(t, err) + + profileImageBytes := []byte("profile image") + err = os.WriteFile(profileImageFile.Name(), profileImageBytes, 0o600) + require.NoError(t, err) + + api.On("EnsureBotUser", testbot).Return(expectedBotID, nil) + api.On("GetBundlePath").Return("", nil) + api.On("SetProfileImage", expectedBotID, profileImageBytes).Return(nil) + api.On("GetServerVersion").Return("5.10.0") + + botID, err := client.Bot.ensureBot(m, testbot, ProfileImagePath(profileImageFile.Name())) + require.NoError(t, err) + assert.Equal(t, expectedBotID, botID) + }) + }) +} + +func newAppError() *model.AppError { + return model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) +} + +type testMutex struct { +} + +func (m testMutex) Lock() {} +func (m testMutex) Unlock() {} diff --git a/server/public/pluginapi/channel.go b/server/public/pluginapi/channel.go new file mode 100644 index 0000000000..ecf9b5ad87 --- /dev/null +++ b/server/public/pluginapi/channel.go @@ -0,0 +1,286 @@ +package pluginapi + +import ( + "net/http" + "time" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// ChannelService exposes methods to manipulate channels. +type ChannelService struct { + api plugin.API +} + +// Get gets a channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) Get(channelID string) (*model.Channel, error) { + channel, appErr := c.api.GetChannel(channelID) + + return channel, normalizeAppErr(appErr) +} + +// GetByName gets a channel by its name, given a team id. +// +// Minimum server version: 5.2 +func (c *ChannelService) GetByName(teamID, channelName string, includeDeleted bool) (*model.Channel, error) { + channel, appErr := c.api.GetChannelByName(teamID, channelName, includeDeleted) + + return channel, normalizeAppErr(appErr) +} + +// GetDirect gets a direct message channel. +// +// Note that if the channel does not exist it will create it. +// +// Minimum server version: 5.2 +func (c *ChannelService) GetDirect(userID1, userID2 string) (*model.Channel, error) { + channel, appErr := c.api.GetDirectChannel(userID1, userID2) + + return channel, normalizeAppErr(appErr) +} + +// GetGroup gets a group message channel. +// +// Note that if the channel does not exist it will create it. +// +// Minimum server version: 5.2 +func (c *ChannelService) GetGroup(userIDs []string) (*model.Channel, error) { + channel, appErr := c.api.GetGroupChannel(userIDs) + + return channel, normalizeAppErr(appErr) +} + +// GetByNameForTeamName gets a channel by its name, given a team name. +// +// Minimum server version: 5.2 +func (c *ChannelService) GetByNameForTeamName(teamName, channelName string, includeDeleted bool) (*model.Channel, error) { + channel, appErr := c.api.GetChannelByNameForTeamName(teamName, channelName, includeDeleted) + + return channel, normalizeAppErr(appErr) +} + +// ListForTeamForUser gets a list of channels for given user ID in given team ID. +// +// Minimum server version: 5.6 +func (c *ChannelService) ListForTeamForUser(teamID, userID string, includeDeleted bool) ([]*model.Channel, error) { + channels, appErr := c.api.GetChannelsForTeamForUser(teamID, userID, includeDeleted) + + return channels, normalizeAppErr(appErr) +} + +// ListPublicChannelsForTeam gets a list of all channels. +// +// Minimum server version: 5.2 +func (c *ChannelService) ListPublicChannelsForTeam(teamID string, page, perPage int) ([]*model.Channel, error) { + channels, appErr := c.api.GetPublicChannelsForTeam(teamID, page, perPage) + + return channels, normalizeAppErr(appErr) +} + +// Search returns the channels on a team matching the provided search term. +// +// Minimum server version: 5.6 +func (c *ChannelService) Search(teamID, term string) ([]*model.Channel, error) { + channels, appErr := c.api.SearchChannels(teamID, term) + + return channels, normalizeAppErr(appErr) +} + +// Create creates a channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) Create(channel *model.Channel) error { + createdChannel, appErr := c.api.CreateChannel(channel) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *channel = *createdChannel + + return c.waitForChannelCreation(channel.Id) +} + +// Update updates a channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) Update(channel *model.Channel) error { + updatedChannel, appErr := c.api.UpdateChannel(channel) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *channel = *updatedChannel + + return nil +} + +// Delete deletes a channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) Delete(channelID string) error { + return normalizeAppErr(c.api.DeleteChannel(channelID)) +} + +// GetChannelStats gets statistics for a channel. +// +// Minimum server version: 5.6 +func (c *ChannelService) GetChannelStats(channelID string) (*model.ChannelStats, error) { + channelStats, appErr := c.api.GetChannelStats(channelID) + + return channelStats, normalizeAppErr(appErr) +} + +// GetMember gets a channel membership for a user. +// +// Minimum server version: 5.2 +func (c *ChannelService) GetMember(channelID, userID string) (*model.ChannelMember, error) { + channelMember, appErr := c.api.GetChannelMember(channelID, userID) + + return channelMember, normalizeAppErr(appErr) +} + +// ListMembers gets a channel membership for all users. +// +// Minimum server version: 5.6 +func (c *ChannelService) ListMembers(channelID string, page, perPage int) ([]*model.ChannelMember, error) { + channelMembers, appErr := c.api.GetChannelMembers(channelID, page, perPage) + + return channelMembersToChannelMemberSlice(channelMembers), normalizeAppErr(appErr) +} + +// ListMembersByIDs gets a channel membership for a particular User +// +// Minimum server version: 5.6 +func (c *ChannelService) ListMembersByIDs(channelID string, userIDs []string) ([]*model.ChannelMember, error) { + channelMembers, appErr := c.api.GetChannelMembersByIds(channelID, userIDs) + + return channelMembersToChannelMemberSlice(channelMembers), normalizeAppErr(appErr) +} + +// ListMembersForUser returns all channel memberships on a team for a user. +// +// Minimum server version: 5.10 +func (c *ChannelService) ListMembersForUser(teamID, userID string, page, perPage int) ([]*model.ChannelMember, error) { + channelMembers, appErr := c.api.GetChannelMembersForUser(teamID, userID, page, perPage) + + return channelMembers, normalizeAppErr(appErr) +} + +// AddMember joins a user to a channel (as if they joined themselves). +// This means the user will not receive notifications for joining the channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) AddMember(channelID, userID string) (*model.ChannelMember, error) { + channelMember, appErr := c.api.AddChannelMember(channelID, userID) + + return channelMember, normalizeAppErr(appErr) +} + +// AddUser adds a user to a channel as if the specified user had invited them. +// This means the user will receive the regular notifications for being added to the channel. +// +// Minimum server version: 5.18 +func (c *ChannelService) AddUser(channelID, userID, asUserID string) (*model.ChannelMember, error) { + channelMember, appErr := c.api.AddUserToChannel(channelID, userID, asUserID) + + return channelMember, normalizeAppErr(appErr) +} + +// DeleteMember deletes a channel membership for a user. +// +// Minimum server version: 5.2 +func (c *ChannelService) DeleteMember(channelID, userID string) error { + appErr := c.api.DeleteChannelMember(channelID, userID) + + return normalizeAppErr(appErr) +} + +// UpdateChannelMemberRoles updates a user's roles for a channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) UpdateChannelMemberRoles(channelID, userID, newRoles string) (*model.ChannelMember, error) { + channelMember, appErr := c.api.UpdateChannelMemberRoles(channelID, userID, newRoles) + + return channelMember, normalizeAppErr(appErr) +} + +// UpdateChannelMemberNotifications updates a user's notification properties for a channel. +// +// Minimum server version: 5.2 +func (c *ChannelService) UpdateChannelMemberNotifications(channelID, userID string, notifications map[string]string) (*model.ChannelMember, error) { + channelMember, appErr := c.api.UpdateChannelMemberNotifications(channelID, userID, notifications) + + return channelMember, normalizeAppErr(appErr) +} + +// CreateSidebarCategory creates a new sidebar category for a set of channels. +// +// Minimum server version: 5.38 +func (c *ChannelService) CreateSidebarCategory( + userID, teamID string, newCategory *model.SidebarCategoryWithChannels) error { + category, appErr := c.api.CreateChannelSidebarCategory(userID, teamID, newCategory) + if appErr != nil { + return normalizeAppErr(appErr) + } + *newCategory = *category + + return nil +} + +// GetSidebarCategories returns sidebar categories. +// +// Minimum server version: 5.38 +func (c *ChannelService) GetSidebarCategories(userID, teamID string) (*model.OrderedSidebarCategories, error) { + categories, appErr := c.api.GetChannelSidebarCategories(userID, teamID) + + return categories, normalizeAppErr(appErr) +} + +// UpdateSidebarCategories updates the channel sidebar categories. +// +// Minimum server version: 5.38 +func (c *ChannelService) UpdateSidebarCategories( + userID, teamID string, categories []*model.SidebarCategoryWithChannels) error { + updatedCategories, appErr := c.api.UpdateChannelSidebarCategories(userID, teamID, categories) + if appErr != nil { + return normalizeAppErr(appErr) + } + copy(categories, updatedCategories) + + return nil +} + +func (c *ChannelService) waitForChannelCreation(channelID string) error { + if len(c.api.GetConfig().SqlSettings.DataSourceReplicas) == 0 { + return nil + } + + now := time.Now() + + for time.Since(now) < 1500*time.Millisecond { + time.Sleep(100 * time.Millisecond) + + if _, err := c.api.GetChannel(channelID); err == nil { + // Channel found + return nil + } else if err.StatusCode != http.StatusNotFound { + return err + } + } + + return errors.Errorf("giving up waiting for channel creation, channelID=%s", channelID) +} + +func channelMembersToChannelMemberSlice(cm model.ChannelMembers) []*model.ChannelMember { + cmp := make([]*model.ChannelMember, len(cm)) + for i := 0; i < len(cm); i++ { + cmp[i] = &(cm)[i] + } + + return cmp +} diff --git a/server/public/pluginapi/channel_test.go b/server/public/pluginapi/channel_test.go new file mode 100644 index 0000000000..2e6803ca1f --- /dev/null +++ b/server/public/pluginapi/channel_test.go @@ -0,0 +1,373 @@ +package pluginapi_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestGetMembers(t *testing.T) { + t.Run("empty list", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetChannelMembers", "channelID", 1, 10).Return(nil, nil) + + cm, err := client.Channel.ListMembers("channelID", 1, 10) + require.NoError(t, err) + require.Empty(t, cm) + }) +} + +func TestGetTeamChannelByName(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetChannelByNameForTeamName", "1", "2", true).Return(&model.Channel{TeamId: "3"}, nil) + + channel, err := client.Channel.GetByNameForTeamName("1", "2", true) + require.NoError(t, err) + require.Equal(t, &model.Channel{TeamId: "3"}, channel) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetChannelByNameForTeamName", "1", "2", true).Return(nil, newAppError()) + + channel, err := client.Channel.GetByNameForTeamName("1", "2", true) + require.EqualError(t, err, "here: id, an error occurred") + require.Zero(t, channel) + }) +} + +func TestGetTeamUserChannels(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetChannelsForTeamForUser", "1", "2", true).Return([]*model.Channel{{TeamId: "3"}, {TeamId: "4"}}, nil) + + channels, err := client.Channel.ListForTeamForUser("1", "2", true) + require.NoError(t, err) + require.Equal(t, []*model.Channel{{TeamId: "3"}, {TeamId: "4"}}, channels) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetChannelsForTeamForUser", "1", "2", true).Return(nil, appErr) + + channels, err := client.Channel.ListForTeamForUser("1", "2", true) + require.Equal(t, appErr, err) + require.Len(t, channels, 0) + }) +} + +func TestGetPublicTeamChannels(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetPublicChannelsForTeam", "1", 2, 3).Return([]*model.Channel{{TeamId: "3"}, {TeamId: "4"}}, nil) + + channels, err := client.Channel.ListPublicChannelsForTeam("1", 2, 3) + require.NoError(t, err) + require.Equal(t, []*model.Channel{{TeamId: "3"}, {TeamId: "4"}}, channels) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetPublicChannelsForTeam", "1", 2, 3).Return(nil, appErr) + + channels, err := client.Channel.ListPublicChannelsForTeam("1", 2, 3) + require.Equal(t, appErr, err) + require.Len(t, channels, 0) + }) +} + +func TestCreateChannel(t *testing.T) { + t.Run("create channel with no replicas", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DataSourceReplicas: []string{}, + }, + } + api.On("GetConfig").Return(config).Once() + + c := &model.Channel{ + Id: model.NewId(), + Name: "name", + DisplayName: "displayname", + } + api.On("CreateChannel", c).Return(c, nil).Once() + + err := client.Channel.Create(c) + require.NoError(t, err) + }) + + t.Run("create channel and wait once", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DataSourceReplicas: []string{"replica1"}, + }, + } + api.On("GetConfig").Return(config).Once() + + c := &model.Channel{ + Id: model.NewId(), + Name: "name", + DisplayName: "displayname", + } + api.On("CreateChannel", c).Return(c, nil).Once() + api.On("GetChannel", c.Id).Return(c, nil).Once() + + err := client.Channel.Create(c) + require.NoError(t, err) + }) + + t.Run("create channel and wait multiple times", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DataSourceReplicas: []string{"replica1"}, + }, + } + api.On("GetConfig").Return(config).Once() + + c := &model.Channel{ + Id: model.NewId(), + Name: "name", + DisplayName: "displayname", + } + api.On("CreateChannel", c).Return(c, nil).Once() + + notFoundErr := model.NewAppError("", "", nil, "", http.StatusNotFound) + api.On("GetChannel", c.Id).Return(c, notFoundErr).Times(3) + api.On("GetChannel", c.Id).Return(c, nil).Times(1) + + err := client.Channel.Create(c) + require.NoError(t, err) + }) + + t.Run("create channel, wait multiple times and return error", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DataSourceReplicas: []string{"replica1"}, + }, + } + api.On("GetConfig").Return(config).Once() + + c := &model.Channel{ + Id: model.NewId(), + Name: "name", + DisplayName: "displayname", + } + api.On("CreateChannel", c).Return(c, nil).Once() + + notFoundErr := model.NewAppError("", "", nil, "", http.StatusNotFound) + api.On("GetChannel", c.Id).Return(c, notFoundErr).Times(3) + + otherErr := model.NewAppError("", "", nil, "", http.StatusInternalServerError) + api.On("GetChannel", c.Id).Return(c, otherErr).Times(1) + + err := client.Channel.Create(c) + require.Error(t, err) + }) + + t.Run("create channel, give up waiting", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DataSourceReplicas: []string{"replica1"}, + }, + } + api.On("GetConfig").Return(config).Once() + + c := &model.Channel{ + Id: model.NewId(), + Name: "name", + DisplayName: "displayname", + } + api.On("CreateChannel", c).Return(c, nil).Once() + + notFoundErr := model.NewAppError("", "", nil, "", http.StatusNotFound) + api.On("GetChannel", c.Id).Return(c, notFoundErr) + + err := client.Channel.Create(c) + require.Error(t, err) + require.Contains(t, err.Error(), "giving up waiting") + }) +} + +func TestCreateSidebarCategory(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + category := model.SidebarCategoryWithChannels{} + + api.On("CreateChannelSidebarCategory", "user_id", "team_id", &category). + Return(&model.SidebarCategoryWithChannels{ + SidebarCategory: model.SidebarCategory{ + Id: "id", + UserId: "user_id", + TeamId: "team_id", + }, + Channels: []string{"channelA", "channelB"}}, + nil) + + err := client.Channel.CreateSidebarCategory("user_id", "team_id", &category) + + require.NoError(t, err) + require.Equal(t, + model.SidebarCategoryWithChannels{ + SidebarCategory: model.SidebarCategory{Id: "id", UserId: "user_id", TeamId: "team_id"}, + Channels: []string{"channelA", "channelB"}}, + category) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + inputCategory := model.SidebarCategoryWithChannels{} + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("CreateChannelSidebarCategory", "user_id", "team_id", &inputCategory). + Return(&model.SidebarCategoryWithChannels{}, appErr) + + err := client.Channel.CreateSidebarCategory("user_id", "team_id", &inputCategory) + + require.Equal(t, appErr, err) + }) +} + +func TestGetSidebarCategories(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetChannelSidebarCategories", "user_id", "team_id"). + Return(&model.OrderedSidebarCategories{ + Categories: nil, + Order: []string{"channelA", "channelB"}, + }, + nil) + + categories, err := client.Channel.GetSidebarCategories("user_id", "team_id") + + require.NoError(t, err) + require.Equal(t, + model.OrderedSidebarCategories{ + Categories: nil, + Order: []string{"channelA", "channelB"}, + }, + *categories) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetChannelSidebarCategories", "user_id", "team_id").Return(nil, appErr) + + _, err := client.Channel.GetSidebarCategories("user_id", "team_id") + + require.Equal(t, appErr, err) + }) +} + +func TestUpdateSidebarCategories(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + categories := []*model.SidebarCategoryWithChannels{ + { + SidebarCategory: model.SidebarCategory{}, + Channels: nil, + }, + } + updatedCategories := []*model.SidebarCategoryWithChannels{ + { + SidebarCategory: model.SidebarCategory{ + Id: "id", + UserId: "user_id", + TeamId: "team_id", + }, + Channels: []string{"channelA", "channelB"}, + }} + + api.On("UpdateChannelSidebarCategories", "user_id", "team_id", categories).Return(updatedCategories, nil) + + err := client.Channel.UpdateSidebarCategories("user_id", "team_id", categories) + + require.NoError(t, err) + require.EqualValues(t, updatedCategories, categories) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + inputCategories := []*model.SidebarCategoryWithChannels{ + { + SidebarCategory: model.SidebarCategory{}, + Channels: nil, + }, + } + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("UpdateChannelSidebarCategories", "user_id", "team_id", inputCategories).Return(nil, appErr) + + err := client.Channel.UpdateSidebarCategories("user_id", "team_id", inputCategories) + + require.Equal(t, appErr, err) + }) +} diff --git a/server/public/pluginapi/client.go b/server/public/pluginapi/client.go new file mode 100644 index 0000000000..e917ecb8c7 --- /dev/null +++ b/server/public/pluginapi/client.go @@ -0,0 +1,79 @@ +package pluginapi + +import ( + "github.com/blang/semver/v4" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/pkg/errors" +) + +// Client is a streamlined wrapper over the mattermost plugin API. +type Client struct { + api plugin.API + + Bot BotService + Channel ChannelService + Cluster ClusterService + Configuration ConfigurationService + SlashCommand SlashCommandService + OAuth OAuthService + Emoji EmojiService + File FileService + Frontend FrontendService + Group GroupService + KV KVService + Log LogService + Mail MailService + Plugin PluginService + Post PostService + Session SessionService + Store *StoreService + System SystemService + Team TeamService + User UserService +} + +// NewClient creates a new instance of Client. +// +// This client must only be created once per plugin to +// prevent reacquiring of resources. +func NewClient(api plugin.API, driver plugin.Driver) *Client { + return &Client{ + api: api, + + Bot: BotService{api: api}, + Channel: ChannelService{api: api}, + Cluster: ClusterService{api: api}, + Configuration: ConfigurationService{api: api}, + SlashCommand: SlashCommandService{api: api}, + OAuth: OAuthService{api: api}, + Emoji: EmojiService{api: api}, + File: FileService{api: api}, + Frontend: FrontendService{api: api}, + Group: GroupService{api: api}, + KV: KVService{api: api}, + Log: LogService{api: api}, + Mail: MailService{api: api}, + Plugin: PluginService{api: api}, + Post: PostService{api: api}, + Session: SessionService{api: api}, + Store: &StoreService{ + api: api, + driver: driver, + }, + System: SystemService{api: api}, + Team: TeamService{api: api}, + User: UserService{api: api}, + } +} + +func ensureServerVersion(api plugin.API, required string) error { + serverVersion := api.GetServerVersion() + currentVersion := semver.MustParse(serverVersion) + requiredVersion := semver.MustParse(required) + + if currentVersion.LT(requiredVersion) { + return errors.Errorf("incompatible server version for plugin, minimum required version: %s, current version: %s", required, serverVersion) + } + + return nil +} diff --git a/server/public/pluginapi/cluster.go b/server/public/pluginapi/cluster.go new file mode 100644 index 0000000000..23af8d7a32 --- /dev/null +++ b/server/public/pluginapi/cluster.go @@ -0,0 +1,23 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// ClusterService exposes methods to interact with cluster nodes. +type ClusterService struct { + api plugin.API +} + +// ClusterService broadcasts a plugin event to all other running instances of +// the calling plugin that are present in the cluster. +// +// This method is used to allow plugin communication in a High-Availability cluster. +// The receiving side should implement the OnPluginClusterEvent hook +// to receive events sent through this method. +// +// Minimum server version: 5.36 +func (c *ClusterService) PublishPluginEvent(ev model.PluginClusterEvent, opts model.PluginClusterEventSendOptions) error { + return c.api.PublishPluginClusterEvent(ev, opts) +} diff --git a/server/public/pluginapi/cluster/doc.go b/server/public/pluginapi/cluster/doc.go new file mode 100644 index 0000000000..2024d02b27 --- /dev/null +++ b/server/public/pluginapi/cluster/doc.go @@ -0,0 +1,3 @@ +// package cluster exposes synchronization primitives to ensure correct behavior across multiple +// plugin instances in a Mattermost cluster. +package cluster diff --git a/server/public/pluginapi/cluster/job.go b/server/public/pluginapi/cluster/job.go new file mode 100644 index 0000000000..b4e43b29e2 --- /dev/null +++ b/server/public/pluginapi/cluster/job.go @@ -0,0 +1,229 @@ +package cluster + +import ( + "encoding/json" + "sync" + "time" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" +) + +const ( + // cronPrefix is used to namespace key values created for a job from other key values + // created by a plugin. + cronPrefix = "cron_" +) + +// JobPluginAPI is the plugin API interface required to schedule jobs. +type JobPluginAPI interface { + MutexPluginAPI + KVGet(key string) ([]byte, *model.AppError) + KVDelete(key string) *model.AppError + KVList(page, count int) ([]string, *model.AppError) +} + +// JobConfig defines the configuration of a scheduled job. +type JobConfig struct { + // Interval is the period of execution for the job. + Interval time.Duration +} + +// NextWaitInterval is a callback computing the next wait interval for a job. +type NextWaitInterval func(now time.Time, metadata JobMetadata) time.Duration + +// MakeWaitForInterval creates a function to scheduling a job to run on the given interval relative +// to the last finished timestamp. +// +// For example, if the job first starts at 12:01 PM, and is configured with interval 5 minutes, +// it will next run at: +// +// 12:06, 12:11, 12:16, ... +// +// If the job has not previously started, it will run immediately. +func MakeWaitForInterval(interval time.Duration) NextWaitInterval { + if interval == 0 { + panic("must specify non-zero ready interval") + } + + return func(now time.Time, metadata JobMetadata) time.Duration { + sinceLastFinished := now.Sub(metadata.LastFinished) + if sinceLastFinished < interval { + return interval - sinceLastFinished + } + + return 0 + } +} + +// MakeWaitForRoundedInterval creates a function, scheduling a job to run on the nearest rounded +// interval relative to the last finished timestamp. +// +// For example, if the job first starts at 12:04 PM, and is configured with interval 5 minutes, +// and is configured to round to 5 minute intervals, it will next run at: +// +// 12:05 PM, 12:10 PM, 12:15 PM, ... +// +// If the job has not previously started, it will run immediately. Note that this wait interval +// strategy does not guarantee a minimum interval between runs, only that subsequent runs will be +// scheduled on the rounded interval. +func MakeWaitForRoundedInterval(interval time.Duration) NextWaitInterval { + if interval == 0 { + panic("must specify non-zero ready interval") + } + + return func(now time.Time, metadata JobMetadata) time.Duration { + if metadata.LastFinished.IsZero() { + return 0 + } + + target := metadata.LastFinished.Add(interval).Truncate(interval) + untilTarget := target.Sub(now) + if untilTarget > 0 { + return untilTarget + } + + return 0 + } +} + +// Job is a scheduled job whose callback function is executed on a configured interval by at most +// one plugin instance at a time. +// +// Use scheduled jobs to perform background activity on a regular interval without having to +// explicitly coordinate with other instances of the same plugin that might repeat that effort. +type Job struct { + pluginAPI JobPluginAPI + key string + mutex *Mutex + nextWaitInterval NextWaitInterval + callback func() + + stopOnce sync.Once + stop chan bool + done chan bool +} + +// JobMetadata persists metadata about job execution. +type JobMetadata struct { + // LastFinished is the last time the job finished anywhere in the cluster. + LastFinished time.Time +} + +// Schedule creates a scheduled job. +func Schedule(pluginAPI JobPluginAPI, key string, nextWaitInterval NextWaitInterval, callback func()) (*Job, error) { + key = cronPrefix + key + + mutex, err := NewMutex(pluginAPI, key) + if err != nil { + return nil, errors.Wrap(err, "failed to create job mutex") + } + + job := &Job{ + pluginAPI: pluginAPI, + key: key, + mutex: mutex, + nextWaitInterval: nextWaitInterval, + callback: callback, + stop: make(chan bool), + done: make(chan bool), + } + + go job.run() + + return job, nil +} + +// readMetadata reads the job execution metadata from the kv store. +func (j *Job) readMetadata() (JobMetadata, error) { + data, appErr := j.pluginAPI.KVGet(j.key) + if appErr != nil { + return JobMetadata{}, errors.Wrap(appErr, "failed to read data") + } + + if data == nil { + return JobMetadata{}, nil + } + + var metadata JobMetadata + err := json.Unmarshal(data, &metadata) + if err != nil { + return JobMetadata{}, errors.Wrap(err, "failed to decode data") + } + + return metadata, nil +} + +// saveMetadata writes updated job execution metadata from the kv store. +// +// It is assumed that the job mutex is held, negating the need to require an atomic write. +func (j *Job) saveMetadata(metadata JobMetadata) error { + data, err := json.Marshal(metadata) + if err != nil { + return errors.Wrap(err, "failed to marshal data") + } + + ok, appErr := j.pluginAPI.KVSetWithOptions(j.key, data, model.PluginKVSetOptions{}) + if appErr != nil || !ok { + return errors.Wrap(appErr, "failed to set data") + } + + return nil +} + +// run attempts to run the scheduled job, guaranteeing only one instance is executing concurrently. +func (j *Job) run() { + defer close(j.done) + + var waitInterval time.Duration + + for { + select { + case <-j.stop: + return + case <-time.After(waitInterval): + } + + func() { + // Acquire the corresponding job lock and hold it throughout execution. + j.mutex.Lock() + defer j.mutex.Unlock() + + metadata, err := j.readMetadata() + if err != nil { + j.pluginAPI.LogError("failed to read job metadata", "err", err, "key", j.key) + waitInterval = nextWaitInterval(waitInterval, err) + return + } + + // Is it time to run the job? + waitInterval = j.nextWaitInterval(time.Now(), metadata) + if waitInterval > 0 { + return + } + + // Run the job + j.callback() + + metadata.LastFinished = time.Now() + + err = j.saveMetadata(metadata) + if err != nil { + j.pluginAPI.LogError("failed to write job data", "err", err, "key", j.key) + } + + waitInterval = j.nextWaitInterval(time.Now(), metadata) + }() + } +} + +// Close terminates a scheduled job, preventing it from being scheduled on this plugin instance. +func (j *Job) Close() error { + j.stopOnce.Do(func() { + close(j.stop) + }) + <-j.done + + return nil +} diff --git a/server/public/pluginapi/cluster/job_example_test.go b/server/public/pluginapi/cluster/job_example_test.go new file mode 100644 index 0000000000..3d5235c39d --- /dev/null +++ b/server/public/pluginapi/cluster/job_example_test.go @@ -0,0 +1,25 @@ +package cluster + +import ( + "time" + + "github.com/mattermost/mattermost/server/public/plugin" +) + +func ExampleSchedule() { + // Use p.API from your plugin instead. + pluginAPI := plugin.API(nil) + + callback := func() { + // periodic work to do + } + + job, err := Schedule(pluginAPI, "key", MakeWaitForInterval(5*time.Minute), callback) + if err != nil { + panic("failed to schedule job") + } + + // main thread + + defer job.Close() +} diff --git a/server/public/pluginapi/cluster/job_once.go b/server/public/pluginapi/cluster/job_once.go new file mode 100644 index 0000000000..7b05bd866a --- /dev/null +++ b/server/public/pluginapi/cluster/job_once.go @@ -0,0 +1,235 @@ +package cluster + +import ( + "encoding/json" + "math/rand" + "sync" + "time" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" +) + +const ( + // oncePrefix is used to namespace key values created for a scheduleOnce job + oncePrefix = "once_" + + // keysPerPage is the maximum number of keys to retrieve from the db per call + keysPerPage = 1000 + + // maxNumFails is the maximum number of KVStore read fails or failed attempts to run the + // callback until the scheduler cancels a job. + maxNumFails = 3 + + // waitAfterFail is the amount of time to wait after a failure + waitAfterFail = 1 * time.Second + + // pollNewJobsInterval is the amount of time to wait between polling the db for new scheduled jobs + pollNewJobsInterval = 5 * time.Minute + + // scheduleOnceJitter is the range of jitter to add to intervals to avoid contention issues + scheduleOnceJitter = 100 * time.Millisecond + + // propsLimit is the maximum length in bytes of the json-representation of a job's props. + // It exists to prevent job go rountines from consuming too much memory, as they are long running. + propsLimit = 10000 +) + +type JobOnceMetadata struct { + Key string + RunAt time.Time + Props any +} + +type JobOnce struct { + pluginAPI JobPluginAPI + clusterMutex *Mutex + + // key is the original key. It is prefixed with oncePrefix when used as a key in the KVStore + key string + props any + runAt time.Time + numFails int + + // done signals the job.run go routine to exit + done chan bool + doneOnce sync.Once + + // join is a join point for the job.run() goroutine to join the calling goroutine (in this case, + // the one calling job.Cancel) + join chan bool + joinOnce sync.Once + + storedCallback *syncedCallback + activeJobs *syncedJobs +} + +// Cancel terminates a scheduled job, preventing it from being scheduled on this plugin instance. +// It also removes the job from the db, preventing it from being run in the future. +func (j *JobOnce) Cancel() { + j.clusterMutex.Lock() + defer j.clusterMutex.Unlock() + + j.cancelWhileHoldingMutex() + + // join the running goroutine + j.joinOnce.Do(func() { + <-j.join + }) +} + +func newJobOnce(pluginAPI JobPluginAPI, key string, runAt time.Time, callback *syncedCallback, jobs *syncedJobs, props any) (*JobOnce, error) { + mutex, err := NewMutex(pluginAPI, key) + if err != nil { + return nil, errors.Wrap(err, "failed to create job mutex") + } + + propsBytes, err := json.Marshal(props) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal props") + } + + if len(propsBytes) > propsLimit { + return nil, errors.Errorf("props length extends limit") + } + + return &JobOnce{ + pluginAPI: pluginAPI, + clusterMutex: mutex, + key: key, + props: props, + runAt: runAt, + done: make(chan bool), + join: make(chan bool), + storedCallback: callback, + activeJobs: jobs, + }, nil +} + +func (j *JobOnce) run() { + defer close(j.join) + + wait := time.Until(j.runAt) + + for { + select { + case <-j.done: + return + case <-time.After(wait + addJitter()): + } + + func() { + // Acquire the cluster mutex while we're trying to do the job + j.clusterMutex.Lock() + defer j.clusterMutex.Unlock() + + // Check that the job has not been completed + metadata, err := readMetadata(j.pluginAPI, j.key) + if err != nil { + j.numFails++ + if j.numFails > maxNumFails { + j.cancelWhileHoldingMutex() + return + } + + // wait a bit of time and try again + wait = waitAfterFail + return + } + + // If key doesn't exist, or if the runAt has changed, the original job has been completed already + if metadata == nil || !j.runAt.Equal(metadata.RunAt) { + j.cancelWhileHoldingMutex() + return + } + + j.executeJob() + + j.cancelWhileHoldingMutex() + }() + } +} + +func (j *JobOnce) executeJob() { + j.storedCallback.mu.Lock() + defer j.storedCallback.mu.Unlock() + + j.storedCallback.callback(j.key, j.props) +} + +// readMetadata reads the job's stored metadata. If the caller wishes to make an atomic +// read/write, the cluster mutex for job's key should be held. +func readMetadata(pluginAPI JobPluginAPI, key string) (*JobOnceMetadata, error) { + data, appErr := pluginAPI.KVGet(oncePrefix + key) + if appErr != nil { + return nil, errors.Wrap(normalizeAppErr(appErr), "failed to read data") + } + + if data == nil { + return nil, nil + } + + var metadata JobOnceMetadata + if err := json.Unmarshal(data, &metadata); err != nil { + return nil, errors.Wrap(err, "failed to decode data") + } + + return &metadata, nil +} + +// saveMetadata writes the job's metadata to the kvstore. saveMetadata acquires the job's cluster lock. +// saveMetadata will not overwrite an existing key. +func (j *JobOnce) saveMetadata() error { + j.clusterMutex.Lock() + defer j.clusterMutex.Unlock() + + metadata := JobOnceMetadata{ + Key: j.key, + Props: j.props, + RunAt: j.runAt, + } + data, err := json.Marshal(metadata) + if err != nil { + return errors.Wrap(err, "failed to marshal data") + } + + ok, appErr := j.pluginAPI.KVSetWithOptions(oncePrefix+j.key, data, model.PluginKVSetOptions{ + Atomic: true, + OldValue: nil, + }) + if appErr != nil { + return normalizeAppErr(appErr) + } + if !ok { + return errors.New("failed to set data") + } + + return nil +} + +// cancelWhileHoldingMutex assumes the caller holds the job's mutex. +func (j *JobOnce) cancelWhileHoldingMutex() { + // remove the job from the kv store, if it exists + _ = j.pluginAPI.KVDelete(oncePrefix + j.key) + + j.activeJobs.mu.Lock() + defer j.activeJobs.mu.Unlock() + delete(j.activeJobs.jobs, j.key) + + j.doneOnce.Do(func() { + close(j.done) + }) +} + +func addJitter() time.Duration { + return time.Duration(rand.Int63n(int64(scheduleOnceJitter))) +} + +func normalizeAppErr(appErr *model.AppError) error { + if appErr == nil { + return nil + } + + return appErr +} diff --git a/server/public/pluginapi/cluster/job_once_example_test.go b/server/public/pluginapi/cluster/job_once_example_test.go new file mode 100644 index 0000000000..cd058bf62a --- /dev/null +++ b/server/public/pluginapi/cluster/job_once_example_test.go @@ -0,0 +1,45 @@ +package cluster + +import ( + "log" + "time" + + "github.com/mattermost/mattermost/server/public/plugin" +) + +func HandleJobOnceCalls(key string, props any) { + if key == "the key i'm watching for" { + log.Println(props) + // Work to do only once per cluster + } +} + +func ExampleJobOnceScheduler_ScheduleOnce() { + // Use p.API from your plugin instead. + pluginAPI := plugin.API(nil) + + // Get the scheduler, which you can pass throughout the plugin... + scheduler := GetJobOnceScheduler(pluginAPI) + + // Set the plugin's callback handler + _ = scheduler.SetCallback(HandleJobOnceCalls) + + // Now start the scheduler, which starts the poller and schedules all waiting jobs. + _ = scheduler.Start() + + // main thread... + + // add a job + _, _ = scheduler.ScheduleOnce("the key i'm watching for", time.Now().Add(2*time.Hour), struct{ foo string }{"aasd"}) + + // Maybe you want to check the scheduled jobs, or cancel them. This is completely optional--there + // is no need to cancel jobs, even if you are shutting down. Call Cancel only when you want to + // cancel a future job. Cancelling a job will prevent it from running in the future on this or + // any server. + jobs, _ := scheduler.ListScheduledJobs() + defer func() { + for _, j := range jobs { + scheduler.Cancel(j.Key) + } + }() +} diff --git a/server/public/pluginapi/cluster/job_once_mem_test.go b/server/public/pluginapi/cluster/job_once_mem_test.go new file mode 100644 index 0000000000..2827ea26cd --- /dev/null +++ b/server/public/pluginapi/cluster/job_once_mem_test.go @@ -0,0 +1,83 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package cluster + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" +) + +func TestMemFootprint(t *testing.T) { + var memConsumed = func() uint64 { + runtime.GC() + var s runtime.MemStats + runtime.ReadMemStats(&s) + return s.Sys + } + + t.Run("average k per jobOnce", func(t *testing.T) { + t.SkipNow() + + makeKey := model.NewId + + numJobs := 100000 + jobs := make(map[string]*int32, numJobs) + for i := 0; i < numJobs; i++ { + jobs[makeKey()] = new(int32) + } + + callback := func(key string, _ any) { + count, ok := jobs[key] + if ok { + atomic.AddInt32(count, 1) + } + } + + mockPluginAPI := newMockPluginAPI(t) + s := GetJobOnceScheduler(mockPluginAPI) + err := s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + getVal := func(key string) []byte { + data, _ := s.pluginAPI.KVGet(key) + return data + } + + before := memConsumed() + + for k := range jobs { + assert.Empty(t, getVal(oncePrefix+k)) + _, err = s.ScheduleOnce(k, time.Now().Add(5*time.Minute), nil) + require.NoError(t, err) + assert.NotEmpty(t, getVal(oncePrefix+k)) + } + + time.Sleep(10 * time.Second) + + // Everything scheduled now: + s.activeJobs.mu.RLock() + assert.Equal(t, numJobs, len(s.activeJobs.jobs)) + s.activeJobs.mu.RUnlock() + list, err := s.ListScheduledJobs() + require.NoError(t, err) + assert.Equal(t, numJobs, len(list)) + + after := memConsumed() + + fmt.Printf("\nthe %d jobs, scheduler, and goroutines require: %.2fmB memory, or %.3fkB each job\n", + numJobs, + float64(after-before)/(1024*1024), + (float64(after-before)/float64(numJobs))/1024) + }) +} diff --git a/server/public/pluginapi/cluster/job_once_scheduler.go b/server/public/pluginapi/cluster/job_once_scheduler.go new file mode 100644 index 0000000000..ca314397ff --- /dev/null +++ b/server/public/pluginapi/cluster/job_once_scheduler.go @@ -0,0 +1,236 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package cluster + +import ( + "strings" + "sync" + "time" + + "github.com/pkg/errors" +) + +// syncedCallback uses the mutex to make things predictable for the client: the callback will be +// called once at a time (the client does not need to worry about concurrency within the callback) +type syncedCallback struct { + mu sync.Mutex + callback func(string, any) +} + +type syncedJobs struct { + mu sync.RWMutex + jobs map[string]*JobOnce +} + +type JobOnceScheduler struct { + pluginAPI JobPluginAPI + + startedMu sync.RWMutex + started bool + + activeJobs *syncedJobs + storedCallback *syncedCallback +} + +var schedulerOnce sync.Once +var s *JobOnceScheduler + +// GetJobOnceScheduler returns a scheduler which is ready to have its callback set. Repeated +// calls will return the same scheduler. +func GetJobOnceScheduler(pluginAPI JobPluginAPI) *JobOnceScheduler { + schedulerOnce.Do(func() { + s = &JobOnceScheduler{ + pluginAPI: pluginAPI, + activeJobs: &syncedJobs{ + jobs: make(map[string]*JobOnce), + }, + storedCallback: &syncedCallback{}, + } + }) + return s +} + +// Start starts the Scheduler. It finds all previous ScheduleOnce jobs and starts them running, and +// fires any jobs that have reached or exceeded their runAt time. Thus, even if a cluster goes down +// and is restarted, Start will restart previously scheduled jobs. +func (s *JobOnceScheduler) Start() error { + s.startedMu.Lock() + defer s.startedMu.Unlock() + if s.started { + return errors.New("scheduler has already been started") + } + + if err := s.verifyCallbackExists(); err != nil { + return errors.Wrap(err, "callback not found; cannot start scheduler") + } + + if err := s.scheduleNewJobsFromDB(); err != nil { + return errors.Wrap(err, "could not start JobOnceScheduler due to error") + } + + go s.pollForNewScheduledJobs() + + s.started = true + + return nil +} + +// SetCallback sets the scheduler's callback. When a job fires, the callback will be called with +// the job's id. +func (s *JobOnceScheduler) SetCallback(callback func(string, any)) error { + if callback == nil { + return errors.New("callback cannot be nil") + } + + s.storedCallback.mu.Lock() + defer s.storedCallback.mu.Unlock() + + s.storedCallback.callback = callback + return nil +} + +// ListScheduledJobs returns a list of the jobs in the db that have been scheduled. There is no +// guarantee that list is accurate by the time the caller reads the list. E.g., the jobs in the list +// may have been run, canceled, or new jobs may have scheduled. +func (s *JobOnceScheduler) ListScheduledJobs() ([]JobOnceMetadata, error) { + var ret []JobOnceMetadata + for i := 0; ; i++ { + keys, err := s.pluginAPI.KVList(i, keysPerPage) + if err != nil { + return nil, errors.Wrap(err, "error getting KVList") + } + for _, k := range keys { + if strings.HasPrefix(k, oncePrefix) { + metadata, err := readMetadata(s.pluginAPI, k[len(oncePrefix):]) + if err != nil { + s.pluginAPI.LogError(errors.Wrap(err, "could not retrieve data from plugin kvstore for key: "+k).Error()) + continue + } + if metadata == nil { + continue + } + + ret = append(ret, *metadata) + } + } + + if len(keys) < keysPerPage { + break + } + } + + return ret, nil +} + +// ScheduleOnce creates a scheduled job that will run once. When the clock reaches runAt, the +// callback will be called with key and props as the argument. +// +// If the job key already exists in the db, this will return an error. To reschedule a job, first +// cancel the original then schedule it again. +func (s *JobOnceScheduler) ScheduleOnce(key string, runAt time.Time, props any) (*JobOnce, error) { + s.startedMu.RLock() + defer s.startedMu.RUnlock() + if !s.started { + return nil, errors.New("start the scheduler before adding jobs") + } + + job, err := newJobOnce(s.pluginAPI, key, runAt, s.storedCallback, s.activeJobs, props) + if err != nil { + return nil, errors.Wrap(err, "could not create new job") + } + + if err = job.saveMetadata(); err != nil { + return nil, errors.Wrap(err, "could not save job metadata") + } + + s.runAndTrack(job) + + return job, nil +} + +// Cancel cancels a job by its key. This is useful if the plugin lost the original *JobOnce, or +// is stopping a job found in ListScheduledJobs(). +func (s *JobOnceScheduler) Cancel(key string) { + // using an anonymous function because job.Close() below needs access to the activeJobs mutex + job := func() *JobOnce { + s.activeJobs.mu.RLock() + defer s.activeJobs.mu.RUnlock() + j, ok := s.activeJobs.jobs[key] + if ok { + return j + } + + // Job wasn't active, so no need to call CancelWhileHoldingMutex (which shuts down the + // goroutine). There's a condition where another server in the cluster started the job, and + // the current server hasn't polled for it yet. To solve that case, delete it from the db. + mutex, err := NewMutex(s.pluginAPI, key) + if err != nil { + s.pluginAPI.LogError(errors.Wrap(err, "failed to create job mutex in Cancel for key: "+key).Error()) + } + mutex.Lock() + defer mutex.Unlock() + + _ = s.pluginAPI.KVDelete(oncePrefix + key) + + return nil + }() + + if job != nil { + job.Cancel() + } +} + +func (s *JobOnceScheduler) scheduleNewJobsFromDB() error { + scheduled, err := s.ListScheduledJobs() + if err != nil { + return errors.Wrap(err, "could not read scheduled jobs from db") + } + + for _, m := range scheduled { + job, err := newJobOnce(s.pluginAPI, m.Key, m.RunAt, s.storedCallback, s.activeJobs, m.Props) + if err != nil { + s.pluginAPI.LogError(errors.Wrap(err, "could not create new job for key: "+m.Key).Error()) + continue + } + + s.runAndTrack(job) + } + + return nil +} + +func (s *JobOnceScheduler) runAndTrack(job *JobOnce) { + s.activeJobs.mu.Lock() + defer s.activeJobs.mu.Unlock() + + // has this been scheduled already on this server? + if _, ok := s.activeJobs.jobs[job.key]; ok { + return + } + + go job.run() + + s.activeJobs.jobs[job.key] = job +} + +// pollForNewScheduledJobs will only be started once per plugin. It doesn't need to be stopped. +func (s *JobOnceScheduler) pollForNewScheduledJobs() { + for { + <-time.After(pollNewJobsInterval + addJitter()) + + if err := s.scheduleNewJobsFromDB(); err != nil { + s.pluginAPI.LogError("pluginAPI scheduleOnce poller encountered an error but is still polling", "error", err) + } + } +} + +func (s *JobOnceScheduler) verifyCallbackExists() error { + s.storedCallback.mu.Lock() + defer s.storedCallback.mu.Unlock() + + if s.storedCallback.callback == nil { + return errors.New("set callback before starting the scheduler") + } + return nil +} diff --git a/server/public/pluginapi/cluster/job_once_test.go b/server/public/pluginapi/cluster/job_once_test.go new file mode 100644 index 0000000000..85a2d0f8f8 --- /dev/null +++ b/server/public/pluginapi/cluster/job_once_test.go @@ -0,0 +1,685 @@ +package cluster + +import ( + "encoding/json" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScheduleOnceParallel(t *testing.T) { + makeKey := model.NewId + + // there is only one callback by design, so all tests need to add their key + // and callback handling code here. + jobKey1 := makeKey() + count1 := new(int32) + jobKey2 := makeKey() + count2 := new(int32) + jobKey3 := makeKey() + jobKey4 := makeKey() + count4 := new(int32) + jobKey5 := makeKey() + count5 := new(int32) + + manyJobs := make(map[string]*int32) + for i := 0; i < 100; i++ { + manyJobs[makeKey()] = new(int32) + } + + callback := func(key string, _ any) { + switch key { + case jobKey1: + atomic.AddInt32(count1, 1) + case jobKey2: + atomic.AddInt32(count2, 1) + case jobKey3: + return // do nothing, like an error occurred in the plugin + case jobKey4: + atomic.AddInt32(count4, 1) + case jobKey5: + atomic.AddInt32(count5, 1) + default: + count, ok := manyJobs[key] + if ok { + atomic.AddInt32(count, 1) + return + } + } + } + + mockPluginAPI := newMockPluginAPI(t) + getVal := func(key string) []byte { + data, _ := mockPluginAPI.KVGet(key) + return data + } + + s := GetJobOnceScheduler(mockPluginAPI) + + // should error if we try to start without callback + err := s.Start() + require.Error(t, err) + + err = s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + require.Empty(t, jobs) + + t.Run("one scheduled job", func(t *testing.T) { + t.Parallel() + + job, err2 := s.ScheduleOnce(jobKey1, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey1)) + + time.Sleep(200*time.Millisecond + scheduleOnceJitter) + + assert.Empty(t, getVal(oncePrefix+jobKey1)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[jobKey1]) + s.activeJobs.mu.RUnlock() + + // It's okay to cancel jobs extra times, even if they're completed. + job.Cancel() + job.Cancel() + job.Cancel() + job.Cancel() + + // Should have been called once + assert.Equal(t, int32(1), atomic.LoadInt32(count1)) + }) + + t.Run("one job, stopped before firing", func(t *testing.T) { + t.Parallel() + + job, err2 := s.ScheduleOnce(jobKey2, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey2)) + + job.Cancel() + assert.Empty(t, getVal(oncePrefix+jobKey2)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[jobKey2]) + s.activeJobs.mu.RUnlock() + + time.Sleep(2 * (waitAfterFail + scheduleOnceJitter)) + + // Should not have been called + assert.Equal(t, int32(0), atomic.LoadInt32(count2)) + + // It's okay to cancel jobs extra times, even if they're completed. + job.Cancel() + job.Cancel() + job.Cancel() + job.Cancel() + }) + + t.Run("failed at the plugin, job removed from db", func(t *testing.T) { + t.Parallel() + + job, err2 := s.ScheduleOnce(jobKey3, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey3)) + + time.Sleep(200*time.Millisecond + scheduleOnceJitter) + assert.Empty(t, getVal(oncePrefix+jobKey3)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[jobKey3]) + s.activeJobs.mu.RUnlock() + }) + + t.Run("cancel and restart a job with the same key", func(t *testing.T) { + t.Parallel() + + job, err2 := s.ScheduleOnce(jobKey4, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey4)) + + job.Cancel() + assert.Empty(t, getVal(oncePrefix+jobKey4)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[jobKey4]) + s.activeJobs.mu.RUnlock() + + job, err2 = s.ScheduleOnce(jobKey4, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey4)) + + time.Sleep(200*time.Millisecond + scheduleOnceJitter) + assert.Equal(t, int32(1), atomic.LoadInt32(count4)) + assert.Empty(t, getVal(oncePrefix+jobKey4)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[jobKey4]) + s.activeJobs.mu.RUnlock() + }) + + t.Run("many scheduled jobs", func(t *testing.T) { + t.Parallel() + + for k := range manyJobs { + job, err2 := s.ScheduleOnce(k, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+k)) + } + + time.Sleep(200*time.Millisecond + scheduleOnceJitter) + + for k, v := range manyJobs { + assert.Empty(t, getVal(oncePrefix+k)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[k]) + s.activeJobs.mu.RUnlock() + assert.Equal(t, int32(1), *v) + } + }) + + t.Run("cancel a job by key name", func(t *testing.T) { + t.Parallel() + + job, err2 := s.ScheduleOnce(jobKey5, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err2) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey5)) + s.activeJobs.mu.RLock() + assert.NotEmpty(t, s.activeJobs.jobs[jobKey5]) + s.activeJobs.mu.RUnlock() + + s.Cancel(jobKey5) + + assert.Empty(t, getVal(oncePrefix+jobKey5)) + s.activeJobs.mu.RLock() + assert.Empty(t, s.activeJobs.jobs[jobKey5]) + s.activeJobs.mu.RUnlock() + + // cancel it again doesn't do anything: + s.Cancel(jobKey5) + + time.Sleep(150*time.Millisecond + scheduleOnceJitter) + assert.Equal(t, int32(0), atomic.LoadInt32(count5)) + }) + + t.Run("starting the scheduler again will return an error", func(t *testing.T) { + t.Parallel() + + newScheduler := GetJobOnceScheduler(mockPluginAPI) + err = newScheduler.Start() + require.Error(t, err) + }) +} + +func TestScheduleOnceSequential(t *testing.T) { + makeKey := model.NewId + + // get the existing scheduler + s := GetJobOnceScheduler(newMockPluginAPI(t)) + getVal := func(key string) []byte { + data, _ := s.pluginAPI.KVGet(key) + return data + } + setMetadata := func(key string, metadata JobOnceMetadata) error { + data, err := json.Marshal(metadata) + if err != nil { + return err + } + ok, appErr := s.pluginAPI.KVSetWithOptions(oncePrefix+key, data, model.PluginKVSetOptions{}) + if !ok { + return errors.New("KVSetWithOptions failed") + } + if appErr != nil { + return normalizeAppErr(appErr) + } + return nil + } + + resetScheduler := func() { + s.activeJobs.mu.Lock() + defer s.activeJobs.mu.Unlock() + s.activeJobs.jobs = make(map[string]*JobOnce) + s.storedCallback.mu.Lock() + defer s.storedCallback.mu.Unlock() + s.storedCallback.callback = nil + s.startedMu.Lock() + defer s.startedMu.Unlock() + s.started = false + s.pluginAPI.(*mockPluginAPI).clear() + } + + t.Run("starting the scheduler without a callback will return an error", func(t *testing.T) { + resetScheduler() + + err := s.Start() + require.Error(t, err) + }) + + t.Run("trying to schedule a job without starting will return an error", func(t *testing.T) { + resetScheduler() + + callback := func(key string, _ any) {} + err := s.SetCallback(callback) + require.NoError(t, err) + + _, err = s.ScheduleOnce("will fail", time.Now(), nil) + require.Error(t, err) + }) + + t.Run("adding two callback works, only second one is called", func(t *testing.T) { + resetScheduler() + + newCount2 := new(int32) + newCount3 := new(int32) + + callback2 := func(key string, _ any) { + atomic.AddInt32(newCount2, 1) + } + callback3 := func(key string, _ any) { + atomic.AddInt32(newCount3, 1) + } + + err := s.SetCallback(callback2) + require.NoError(t, err) + err = s.SetCallback(callback3) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + _, err = s.ScheduleOnce("anything", time.Now().Add(50*time.Millisecond), nil) + require.NoError(t, err) + time.Sleep(70*time.Millisecond + scheduleOnceJitter) + assert.Equal(t, int32(0), atomic.LoadInt32(newCount2)) + assert.Equal(t, int32(1), atomic.LoadInt32(newCount3)) + }) + + t.Run("test paging keys from the db by inserting 3 pages of jobs and starting scheduler", func(t *testing.T) { + resetScheduler() + + numPagingJobs := keysPerPage*3 + 2 + testPagingJobs := make(map[string]*int32) + for i := 0; i < numPagingJobs; i++ { + testPagingJobs[makeKey()] = new(int32) + } + + callback := func(key string, _ any) { + count, ok := testPagingJobs[key] + if ok { + atomic.AddInt32(count, 1) + return + } + } + + // add the test paging jobs before starting scheduler + for k := range testPagingJobs { + assert.Empty(t, getVal(oncePrefix+k)) + job, err := newJobOnce(s.pluginAPI, k, time.Now().Add(100*time.Millisecond), s.storedCallback, s.activeJobs, nil) + require.NoError(t, err) + err = job.saveMetadata() + require.NoError(t, err) + assert.NotEmpty(t, getVal(oncePrefix+k)) + } + + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + assert.Equal(t, len(testPagingJobs), len(jobs)) + + err = s.SetCallback(callback) + require.NoError(t, err) + + // reschedule from the db: + err = s.scheduleNewJobsFromDB() + require.NoError(t, err) + + // wait for the testPagingJobs created in the setup to finish + time.Sleep(300 * time.Millisecond) + + numInDB := 0 + numActive := 0 + numCountsAtZero := 0 + for k, v := range testPagingJobs { + if getVal(oncePrefix+k) != nil { + numInDB++ + } + s.activeJobs.mu.RLock() + if s.activeJobs.jobs[k] != nil { + numActive++ + } + s.activeJobs.mu.RUnlock() + if atomic.LoadInt32(v) == int32(0) { + numCountsAtZero++ + } + } + + assert.Equal(t, 0, numInDB) + assert.Equal(t, 0, numActive) + assert.Equal(t, 0, numCountsAtZero) + }) + + t.Run("failed at the db", func(t *testing.T) { + resetScheduler() + + jobKey1 := makeKey() + count1 := new(int32) + + callback := func(key string, _ any) { + if key == jobKey1 { + atomic.AddInt32(count1, 1) + } + } + + err := s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + require.Empty(t, jobs) + + job, err := s.ScheduleOnce(jobKey1, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey1)) + assert.NotEmpty(t, s.activeJobs.jobs[jobKey1]) + s.pluginAPI.(*mockPluginAPI).setFailingWithPrefix(oncePrefix) + + // wait until the metadata has failed to read + time.Sleep((maxNumFails + 1) * (waitAfterFail + scheduleOnceJitter)) + assert.Equal(t, int32(0), atomic.LoadInt32(count1)) + assert.Nil(t, getVal(oncePrefix+jobKey1)) + + assert.Empty(t, s.activeJobs.jobs[jobKey1]) + assert.Empty(t, getVal(oncePrefix+jobKey1)) + assert.Equal(t, int32(0), atomic.LoadInt32(count1)) + + s.pluginAPI.(*mockPluginAPI).setFailingWithPrefix("") + }) + + t.Run("simulate starting the plugin with 3 pending jobs in the db", func(t *testing.T) { + resetScheduler() + + jobKeys := make(map[string]*int32) + for i := 0; i < 3; i++ { + jobKeys[makeKey()] = new(int32) + } + + callback := func(key string, _ any) { + count, ok := jobKeys[key] + if ok { + atomic.AddInt32(count, 1) + } + } + err := s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + for k := range jobKeys { + job, err3 := newJobOnce(s.pluginAPI, k, time.Now().Add(100*time.Millisecond), s.storedCallback, s.activeJobs, nil) + require.NoError(t, err3) + err3 = job.saveMetadata() + require.NoError(t, err3) + assert.NotEmpty(t, getVal(oncePrefix+k)) + } + + // double checking they're in the db: + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + require.Len(t, jobs, 3) + + // simulate starting the plugin + require.NoError(t, err) + err = s.scheduleNewJobsFromDB() + require.NoError(t, err) + + time.Sleep(120*time.Millisecond + scheduleOnceJitter) + + for k, v := range jobKeys { + assert.Empty(t, getVal(oncePrefix+k)) + assert.Empty(t, s.activeJobs.jobs[k]) + assert.Equal(t, int32(1), *v) + } + jobs, err = s.ListScheduledJobs() + require.NoError(t, err) + require.Empty(t, jobs) + }) + + t.Run("starting a job and polling before it's finished results in only one job running", func(t *testing.T) { + resetScheduler() + + jobKey := makeKey() + count := new(int32) + + callback := func(key string, _ any) { + if key == jobKey { + atomic.AddInt32(count, 1) + } + } + + err := s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + require.Empty(t, jobs) + + job, err := s.ScheduleOnce(jobKey, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey)) + s.activeJobs.mu.Lock() + assert.NotEmpty(t, s.activeJobs.jobs[jobKey]) + assert.Len(t, s.activeJobs.jobs, 1) + s.activeJobs.mu.Unlock() + + // simulate what the polling function will do for a long running job: + err = s.scheduleNewJobsFromDB() + require.NoError(t, err) + err = s.scheduleNewJobsFromDB() + require.NoError(t, err) + err = s.scheduleNewJobsFromDB() + require.NoError(t, err) + assert.NotEmpty(t, getVal(oncePrefix+jobKey)) + s.activeJobs.mu.Lock() + assert.NotEmpty(t, s.activeJobs.jobs[jobKey]) + assert.Len(t, s.activeJobs.jobs, 1) + s.activeJobs.mu.Unlock() + + // now wait for it to complete + time.Sleep(120*time.Millisecond + scheduleOnceJitter) + assert.Equal(t, int32(1), atomic.LoadInt32(count)) + assert.Empty(t, getVal(oncePrefix+jobKey)) + s.activeJobs.mu.Lock() + assert.Empty(t, s.activeJobs.jobs) + s.activeJobs.mu.Unlock() + }) + + t.Run("starting the same job again while it's still active will fail", func(t *testing.T) { + resetScheduler() + + jobKey := makeKey() + count := new(int32) + + callback := func(key string, _ any) { + if key == jobKey { + atomic.AddInt32(count, 1) + } + } + + err := s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + require.Empty(t, jobs) + + job, err := s.ScheduleOnce(jobKey, time.Now().Add(100*time.Millisecond), nil) + require.NoError(t, err) + require.NotNil(t, job) + assert.NotEmpty(t, getVal(oncePrefix+jobKey)) + assert.NotEmpty(t, s.activeJobs.jobs[jobKey]) + assert.Len(t, s.activeJobs.jobs, 1) + + // a plugin tries to start the same jobKey again: + job, err = s.ScheduleOnce(jobKey, time.Now().Add(10000*time.Millisecond), nil) + require.Error(t, err) + require.Nil(t, job) + + // now wait for first job to complete + time.Sleep(120*time.Millisecond + scheduleOnceJitter) + assert.Equal(t, int32(1), atomic.LoadInt32(count)) + assert.Empty(t, getVal(oncePrefix+jobKey)) + assert.Empty(t, s.activeJobs.jobs) + }) + + t.Run("simulate HA: canceling and setting a job with a different time--old one shouldn't fire", func(t *testing.T) { + resetScheduler() + + key := makeKey() + jobKeys := make(map[string]*int32) + jobKeys[key] = new(int32) + + // control is like the "control group" in an experiment. It will be overwritten, + // but with the same runAt. It should fire. + control := makeKey() + jobKeys[control] = new(int32) + + callback := func(key string, _ any) { + count, ok := jobKeys[key] + if ok { + atomic.AddInt32(count, 1) + } + } + err := s.SetCallback(callback) + require.NoError(t, err) + err = s.Start() + require.NoError(t, err) + + originalRunAt := time.Now().Add(100 * time.Millisecond) + newRunAt := time.Now().Add(101 * time.Millisecond) + + // store original + job, err := newJobOnce(s.pluginAPI, key, originalRunAt, s.storedCallback, s.activeJobs, nil) + require.NoError(t, err) + err = job.saveMetadata() + require.NoError(t, err) + assert.NotEmpty(t, getVal(oncePrefix+key)) + + // store oringal control + job2, err := newJobOnce(s.pluginAPI, control, originalRunAt, s.storedCallback, s.activeJobs, nil) + require.NoError(t, err) + err = job2.saveMetadata() + require.NoError(t, err) + assert.NotEmpty(t, getVal(oncePrefix+control)) + + // double checking originals are in the db: + jobs, err := s.ListScheduledJobs() + require.NoError(t, err) + require.Len(t, jobs, 2) + require.True(t, originalRunAt.Equal(jobs[0].RunAt)) + require.True(t, originalRunAt.Equal(jobs[1].RunAt)) + + // simulate starting the plugin + require.NoError(t, err) + err = s.scheduleNewJobsFromDB() + require.NoError(t, err) + + // Now "cancel" the original and make a new job with the same key but a different time. + // However, because we have only one list of synced jobs, we can't make two jobs with the + // same key. So we'll simulate this by changing the job metadata in the db. When the original + // job fires, it should see that the runAt is different, and it will think it has been canceled. + err = setMetadata(key, JobOnceMetadata{ + Key: key, + RunAt: newRunAt, + }) + require.NoError(t, err) + + // overwrite the control with the same runAt. It should fire. + err = setMetadata(control, JobOnceMetadata{ + Key: control, + RunAt: originalRunAt, + }) + require.NoError(t, err) + + time.Sleep(120*time.Millisecond + scheduleOnceJitter) + + // original job didn't fire the callback: + assert.Empty(t, getVal(oncePrefix+key)) + assert.Empty(t, s.activeJobs.jobs[key]) + assert.Equal(t, int32(0), *jobKeys[key]) + + // control job did fire the callback: + assert.Empty(t, getVal(oncePrefix+control)) + assert.Empty(t, s.activeJobs.jobs[control]) + assert.Equal(t, int32(1), *jobKeys[control]) + + jobs, err = s.ListScheduledJobs() + require.NoError(t, err) + require.Empty(t, jobs) + }) +} + +func TestScheduleOnceProps(t *testing.T) { + t.Run("confirm props are returned", func(t *testing.T) { + s := GetJobOnceScheduler(newMockPluginAPI(t)) + + jobKey := model.NewId() + jobProps := struct { + Foo string + }{ + Foo: "some foo", + } + + var mut sync.Mutex + var called bool + callback := func(key string, props any) { + require.Equal(t, jobKey, key) + require.Equal(t, jobProps, props) + mut.Lock() + defer mut.Unlock() + called = true + } + + err := s.SetCallback(callback) + require.NoError(t, err) + if !s.started { + err = s.Start() + require.NoError(t, err) + } + + _, err = s.ScheduleOnce(jobKey, time.Now().Add(100*time.Millisecond), jobProps) + require.NoError(t, err) + + // Check if callback was called + require.Eventually(t, func() bool { mut.Lock(); defer mut.Unlock(); return called }, time.Second, 50*time.Millisecond) + }) + + t.Run("props to large", func(t *testing.T) { + s := GetJobOnceScheduler(newMockPluginAPI(t)) + + props := make([]byte, propsLimit) + for i := 0; i < propsLimit; i++ { + props[i] = 'a' + } + + _, err := s.ScheduleOnce(model.NewId(), time.Now().Add(100*time.Millisecond), props) + require.Error(t, err) + }) +} diff --git a/server/public/pluginapi/cluster/job_test.go b/server/public/pluginapi/cluster/job_test.go new file mode 100644 index 0000000000..49019ae800 --- /dev/null +++ b/server/public/pluginapi/cluster/job_test.go @@ -0,0 +1,412 @@ +package cluster + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMakeWaitForInterval(t *testing.T) { + t.Run("panics on invalid interval", func(t *testing.T) { + assert.Panics(t, func() { + MakeWaitForInterval(0) + }) + }) + + const neverRun = -1 * time.Second + + testCases := []struct { + Description string + Interval time.Duration + LastFinished time.Duration + Expected time.Duration + }{ + { + "never run, 5 minutes", + 5 * time.Minute, + neverRun, + 0, + }, + { + "run 1 minute ago, 5 minutes", + 5 * time.Minute, + -1 * time.Minute, + 4 * time.Minute, + }, + { + "run 2 minutes ago, 5 minutes", + 5 * time.Minute, + -2 * time.Minute, + 3 * time.Minute, + }, + { + "run 4 minutes 30 seconds ago, 5 minutes", + 5 * time.Minute, + -4*time.Minute - 30*time.Second, + 30 * time.Second, + }, + { + "run 4 minutes 59 seconds ago, 5 minutes", + 5 * time.Minute, + -4*time.Minute - 59*time.Second, + 1 * time.Second, + }, + { + "never run, 1 hour", + 1 * time.Hour, + neverRun, + 0, + }, + { + "run 1 minute ago, 1 hour", + 1 * time.Hour, + -1 * time.Minute, + 59 * time.Minute, + }, + { + "run 20 minutes ago, 1 hour", + 1 * time.Hour, + -20 * time.Minute, + 40 * time.Minute, + }, + { + "run 55 minutes 30 seconds ago, 1 hour", + 1 * time.Hour, + -55*time.Minute - 30*time.Second, + 4*time.Minute + 30*time.Second, + }, + { + "run 59 minutes 59 seconds ago, 1 hour", + 1 * time.Hour, + -59*time.Minute - 59*time.Second, + 1 * time.Second, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + now := time.Now() + + var lastFinished time.Time + if testCase.LastFinished != neverRun { + lastFinished = now.Add(testCase.LastFinished) + } + + actual := MakeWaitForInterval(testCase.Interval)(now, JobMetadata{ + LastFinished: lastFinished, + }) + assert.Equal(t, testCase.Expected, actual) + }) + } +} + +func TestMakeWaitForRoundedInterval(t *testing.T) { + t.Run("panics on invalid interval", func(t *testing.T) { + assert.Panics(t, func() { + MakeWaitForRoundedInterval(0) + }) + }) + + const neverRun = -1 * time.Second + topOfTheHour := time.Now().Truncate(1 * time.Hour) + topOfTheDay := time.Now().Truncate(24 * time.Hour) + + testCases := []struct { + Description string + Interval time.Duration + Now time.Time + LastFinished time.Duration + Expected time.Duration + }{ + { + "5 minutes, top of the hour, never run", + 5 * time.Minute, + topOfTheHour, + neverRun, + 0, + }, + { + "5 minutes, top of the hour less 1 minute, never run", + 5 * time.Minute, + topOfTheHour.Add(-1 * time.Minute), + neverRun, + 0, + }, + { + "5 minutes, top of the hour less 1 minute, run 1 minute ago", + 5 * time.Minute, + topOfTheHour.Add(-1 * time.Minute), + -1 * time.Minute, + 1 * time.Minute, + }, + { + "5 minutes, top of the hour plus 1 minute, run 2 minutes ago", + 5 * time.Minute, + topOfTheHour.Add(1 * time.Minute), + -2 * time.Minute, + 0, + }, + { + "5 minutes, top of the hour plus 1 minute, run 30 seconds ago", + 5 * time.Minute, + topOfTheHour.Add(1 * time.Minute), + -30 * time.Second, + 4 * time.Minute, + }, + { + "5 minutes, top of the hour plus 7 minutes, run 30 seconds ago", + 5 * time.Minute, + topOfTheHour.Add(7 * time.Minute), + -30 * time.Second, + 3 * time.Minute, + }, + { + "30 minutes, top of the hour, never run", + 30 * time.Minute, + topOfTheHour, + neverRun, + 0, + }, + { + "30 minutes, top of the hour less 1 minute, never run", + 30 * time.Minute, + topOfTheHour.Add(-1 * time.Minute), + neverRun, + 0, + }, + { + "30 minutes, top of the hour less 1 minute, run 1 minute ago", + 30 * time.Minute, + topOfTheHour.Add(-1 * time.Minute), + -1 * time.Minute, + 1 * time.Minute, + }, + { + "30 minutes, top of the hour plus 1 minute, run 2 minutes ago", + 30 * time.Minute, + topOfTheHour.Add(1 * time.Minute), + -2 * time.Minute, + 0, + }, + { + "30 minutes, top of the hour plus 1 minute, run 30 seconds ago", + 30 * time.Minute, + topOfTheHour.Add(1 * time.Minute), + -30 * time.Second, + 29 * time.Minute, + }, + { + "30 minutes, top of the hour plus 7 minutes, run 30 seconds ago", + 30 * time.Minute, + topOfTheHour.Add(7 * time.Minute), + -30 * time.Second, + 23 * time.Minute, + }, + { + "24 hours, top of the day, never run", + 24 * time.Hour, + topOfTheDay, + neverRun, + 0, + }, + { + "24 hours, top of the day less 1 minute, never run", + 24 * time.Hour, + topOfTheDay.Add(-1 * time.Minute), + neverRun, + 0, + }, + { + "24 hours, top of the day less 1 minute, run 1 minute ago", + 24 * time.Hour, + topOfTheDay.Add(-1 * time.Minute), + -1 * time.Minute, + 1 * time.Minute, + }, + { + "24 hours, top of the day plus 1 minute, run 2 minutes ago", + 24 * time.Hour, + topOfTheDay.Add(1 * time.Minute), + -2 * time.Minute, + 0, + }, + { + "24 hours, top of the day plus 1 minute, run 30 seconds ago", + 24 * time.Hour, + topOfTheDay.Add(1 * time.Minute), + -30 * time.Second, + 23*time.Hour + 59*time.Minute, + }, + { + "24 hours, top of the day plus 7 minutes, run 30 seconds ago", + 24 * time.Hour, + topOfTheDay.Add(7 * time.Minute), + -30 * time.Second, + 23*time.Hour + 53*time.Minute, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + var lastFinished time.Time + if testCase.LastFinished != neverRun { + lastFinished = testCase.Now.Add(testCase.LastFinished) + } + + actual := MakeWaitForRoundedInterval(testCase.Interval)(testCase.Now, JobMetadata{ + LastFinished: lastFinished, + }) + assert.Equal(t, testCase.Expected, actual) + }) + } +} + +func TestSchedule(t *testing.T) { + t.Parallel() + + makeKey := model.NewId + + t.Run("single-threaded", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + count := new(int32) + callback := func() { + atomic.AddInt32(count, 1) + } + + job, err := Schedule(mockPluginAPI, makeKey(), MakeWaitForInterval(100*time.Millisecond), callback) + require.NoError(t, err) + require.NotNil(t, job) + + time.Sleep(1 * time.Second) + + err = job.Close() + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + // Shouldn't have hit 20 in this time frame + assert.Less(t, *count, int32(20)) + + // Should have hit at least 5 in this time frame + assert.Greater(t, *count, int32(5)) + }) + + t.Run("multi-threaded, single job", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + count := new(int32) + callback := func() { + atomic.AddInt32(count, 1) + } + + var jobs []*Job + + key := makeKey() + + for i := 0; i < 3; i++ { + job, err := Schedule(mockPluginAPI, key, MakeWaitForInterval(100*time.Millisecond), callback) + require.NoError(t, err) + require.NotNil(t, job) + + jobs = append(jobs, job) + } + + time.Sleep(1 * time.Second) + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + job := jobs[i] + wg.Add(1) + go func() { + defer wg.Done() + err := job.Close() + require.NoError(t, err) + }() + } + wg.Wait() + + time.Sleep(1 * time.Second) + + // Shouldn't have hit 20 in this time frame + assert.Less(t, *count, int32(20)) + + // Should have hit at least 5 in this time frame + assert.Greater(t, *count, int32(5)) + }) + + t.Run("multi-threaded, multiple jobs", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + countA := new(int32) + callbackA := func() { + atomic.AddInt32(countA, 1) + } + + countB := new(int32) + callbackB := func() { + atomic.AddInt32(countB, 1) + } + + keyA := makeKey() + keyB := makeKey() + + var jobs []*Job + for i := 0; i < 3; i++ { + var key string + var callback func() + if i <= 1 { + key = keyA + callback = callbackA + } else { + key = keyB + callback = callbackB + } + + job, err := Schedule(mockPluginAPI, key, MakeWaitForInterval(100*time.Millisecond), callback) + require.NoError(t, err) + require.NotNil(t, job) + + jobs = append(jobs, job) + } + + time.Sleep(1 * time.Second) + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + job := jobs[i] + wg.Add(1) + go func() { + defer wg.Done() + err := job.Close() + require.NoError(t, err) + }() + } + wg.Wait() + + time.Sleep(1 * time.Second) + + // Shouldn't have hit 20 in this time frame + assert.Less(t, *countA, int32(20)) + + // Should have hit at least 5 in this time frame + assert.Greater(t, *countA, int32(5)) + + // Shouldn't have hit 20 in this time frame + assert.Less(t, *countB, int32(20)) + + // Should have hit at least 5 in this time frame + assert.Greater(t, *countB, int32(5)) + }) +} diff --git a/server/public/pluginapi/cluster/mock_plugin_api_test.go b/server/public/pluginapi/cluster/mock_plugin_api_test.go new file mode 100644 index 0000000000..5a5487a4ae --- /dev/null +++ b/server/public/pluginapi/cluster/mock_plugin_api_test.go @@ -0,0 +1,150 @@ +package cluster + +import ( + "bytes" + "sort" + "strings" + "sync" + "testing" + + "github.com/mattermost/mattermost/server/public/model" +) + +type mockPluginAPI struct { + t *testing.T + + lock sync.Mutex + keyValues map[string][]byte + failing bool + failingWithPrefix string +} + +func newMockPluginAPI(t *testing.T) *mockPluginAPI { + return &mockPluginAPI{ + t: t, + keyValues: make(map[string][]byte), + } +} + +func (pluginAPI *mockPluginAPI) setFailing(failing bool) { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + pluginAPI.failing = failing +} + +func (pluginAPI *mockPluginAPI) setFailingWithPrefix(prefix string) { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + pluginAPI.failingWithPrefix = prefix +} + +func (pluginAPI *mockPluginAPI) clear() { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + for k := range pluginAPI.keyValues { + delete(pluginAPI.keyValues, k) + } +} + +func (pluginAPI *mockPluginAPI) KVGet(key string) ([]byte, *model.AppError) { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + if pluginAPI.failing { + return nil, &model.AppError{Message: "fake error"} + } + + if pluginAPI.failingWithPrefix != "" && strings.HasPrefix(key, pluginAPI.failingWithPrefix) { + return nil, &model.AppError{Message: "fake error for prefix " + pluginAPI.failingWithPrefix} + } + + return pluginAPI.keyValues[key], nil +} + +func (pluginAPI *mockPluginAPI) KVDelete(key string) *model.AppError { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + if pluginAPI.failing { + return &model.AppError{Message: "fake error"} + } + + if pluginAPI.failingWithPrefix != "" && strings.HasPrefix(key, pluginAPI.failingWithPrefix) { + return &model.AppError{Message: "fake error for prefix " + pluginAPI.failingWithPrefix} + } + + delete(pluginAPI.keyValues, key) + + return nil +} + +func (pluginAPI *mockPluginAPI) KVList(page, count int) ([]string, *model.AppError) { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + if pluginAPI.failing { + return nil, &model.AppError{Message: "fake error"} + } + + keys := make([]string, 0, len(pluginAPI.keyValues)) + for k := range pluginAPI.keyValues { + keys = append(keys, k) + } + + // have to sort, because we're paging below + sort.Strings(keys) + + start := min(page*count, len(keys)) + end := min((page+1)*count, len(keys)) + return keys[start:end], nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (pluginAPI *mockPluginAPI) KVSetWithOptions(key string, value []byte, options model.PluginKVSetOptions) (bool, *model.AppError) { + pluginAPI.lock.Lock() + defer pluginAPI.lock.Unlock() + + if pluginAPI.failing { + return false, &model.AppError{Message: "fake error"} + } + + if pluginAPI.failingWithPrefix != "" && strings.HasPrefix(key, pluginAPI.failingWithPrefix) { + return false, &model.AppError{Message: "fake error for prefix " + pluginAPI.failingWithPrefix} + } + + if options.Atomic { + if actualValue := pluginAPI.keyValues[key]; !bytes.Equal(actualValue, options.OldValue) { + return false, nil + } + } + + if value == nil { + delete(pluginAPI.keyValues, key) + } else { + pluginAPI.keyValues[key] = value + } + + return true, nil +} + +func (pluginAPI *mockPluginAPI) LogError(msg string, keyValuePairs ...interface{}) { + if pluginAPI.t == nil { + return + } + + pluginAPI.t.Helper() + + params := []interface{}{msg} + params = append(params, keyValuePairs...) + + pluginAPI.t.Log(params...) +} diff --git a/server/public/pluginapi/cluster/mutex.go b/server/public/pluginapi/cluster/mutex.go new file mode 100644 index 0000000000..0169d3b398 --- /dev/null +++ b/server/public/pluginapi/cluster/mutex.go @@ -0,0 +1,185 @@ +package cluster + +import ( + "context" + "sync" + "time" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/pkg/errors" +) + +const ( + // mutexPrefix is used to namespace key values created for a mutex from other key values + // created by a plugin. + mutexPrefix = "mutex_" +) + +const ( + // ttl is the interval after which a locked mutex will expire unless refreshed + ttl = time.Second * 15 + + // refreshInterval is the interval on which the mutex will be refreshed when locked + refreshInterval = ttl / 2 +) + +// MutexPluginAPI is the plugin API interface required to manage mutexes. +type MutexPluginAPI interface { + KVSetWithOptions(key string, value []byte, options model.PluginKVSetOptions) (bool, *model.AppError) + LogError(msg string, keyValuePairs ...interface{}) +} + +// Mutex is similar to sync.Mutex, except usable by multiple plugin instances across a cluster. +// +// Internally, a mutex relies on an atomic key-value set operation as exposed by the Mattermost +// plugin API. +// +// Mutexes with different names are unrelated. Mutexes with the same name from different plugins +// are unrelated. Pick a unique name for each mutex your plugin requires. +// +// A Mutex must not be copied after first use. +type Mutex struct { + pluginAPI MutexPluginAPI + key string + + // lock guards the variables used to manage the refresh task, and is not itself related to + // the cluster-wide lock. + lock sync.Mutex + stopRefresh chan bool + refreshDone chan bool +} + +// NewMutex creates a mutex with the given key name. +// +// Panics if key is empty. +func NewMutex(pluginAPI MutexPluginAPI, key string) (*Mutex, error) { + key, err := makeLockKey(key) + if err != nil { + return nil, err + } + + return &Mutex{ + pluginAPI: pluginAPI, + key: key, + }, nil +} + +// makeLockKey returns the prefixed key used to namespace mutex keys. +func makeLockKey(key string) (string, error) { + if key == "" { + return "", errors.New("must specify valid mutex key") + } + + return mutexPrefix + key, nil +} + +// lock makes a single attempt to atomically lock the mutex, returning true only if successful. +func (m *Mutex) tryLock() (bool, error) { + ok, err := m.pluginAPI.KVSetWithOptions(m.key, []byte{1}, model.PluginKVSetOptions{ + Atomic: true, + OldValue: nil, // No existing key value. + ExpireInSeconds: int64(ttl / time.Second), + }) + if err != nil { + return false, errors.Wrap(err, "failed to set mutex kv") + } + + return ok, nil +} + +// refreshLock rewrites the lock key value with a new expiry, returning true only if successful. +func (m *Mutex) refreshLock() error { + ok, err := m.pluginAPI.KVSetWithOptions(m.key, []byte{1}, model.PluginKVSetOptions{ + Atomic: true, + OldValue: []byte{1}, + ExpireInSeconds: int64(ttl / time.Second), + }) + if err != nil { + return errors.Wrap(err, "failed to refresh mutex kv") + } else if !ok { + return errors.New("unexpectedly failed to refresh mutex kv") + } + + return nil +} + +// Lock locks m. If the mutex is already locked by any plugin instance, including the current one, +// the calling goroutine blocks until the mutex can be locked. +func (m *Mutex) Lock() { + _ = m.LockWithContext(context.Background()) +} + +// LockWithContext locks m unless the context is canceled. If the mutex is already locked by any plugin +// instance, including the current one, the calling goroutine blocks until the mutex can be locked, +// or the context is canceled. +// +// The mutex is locked only if a nil error is returned. +func (m *Mutex) LockWithContext(ctx context.Context) error { + var waitInterval time.Duration + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitInterval): + } + + locked, err := m.tryLock() + if err != nil { + m.pluginAPI.LogError("failed to lock mutex", "err", err, "lock_key", m.key) + waitInterval = nextWaitInterval(waitInterval, err) + continue + } else if !locked { + waitInterval = nextWaitInterval(waitInterval, err) + continue + } + + stop := make(chan bool) + done := make(chan bool) + go func() { + defer close(done) + t := time.NewTicker(refreshInterval) + for { + select { + case <-t.C: + err := m.refreshLock() + if err != nil { + m.pluginAPI.LogError("failed to refresh mutex", "err", err, "lock_key", m.key) + return + } + case <-stop: + return + } + } + }() + + m.lock.Lock() + m.stopRefresh = stop + m.refreshDone = done + m.lock.Unlock() + + return nil + } +} + +// Unlock unlocks m. It is a run-time error if m is not locked on entry to Unlock. +// +// Just like sync.Mutex, a locked Lock is not associated with a particular goroutine or plugin +// instance. It is allowed for one goroutine or plugin instance to lock a Lock and then arrange +// for another goroutine or plugin instance to unlock it. In practice, ownership of the lock should +// remain within a single plugin instance. +func (m *Mutex) Unlock() { + m.lock.Lock() + if m.stopRefresh == nil { + m.lock.Unlock() + panic("mutex has not been acquired") + } + + close(m.stopRefresh) + m.stopRefresh = nil + <-m.refreshDone + m.lock.Unlock() + + // If an error occurs deleting, the mutex kv will still expire, allowing later retry. + _, _ = m.pluginAPI.KVSetWithOptions(m.key, nil, model.PluginKVSetOptions{}) +} diff --git a/server/public/pluginapi/cluster/mutex_example_test.go b/server/public/pluginapi/cluster/mutex_example_test.go new file mode 100644 index 0000000000..8db36c5023 --- /dev/null +++ b/server/public/pluginapi/cluster/mutex_example_test.go @@ -0,0 +1,20 @@ +package cluster_test + +import ( + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/pluginapi/cluster" +) + +//nolint:staticcheck +func ExampleMutex() { + // Use p.API from your plugin instead. + pluginAPI := plugin.API(nil) + + m, err := cluster.NewMutex(pluginAPI, "key") + if err != nil { + panic(err) + } + m.Lock() + // critical section + m.Unlock() +} diff --git a/server/public/pluginapi/cluster/mutex_test.go b/server/public/pluginapi/cluster/mutex_test.go new file mode 100644 index 0000000000..d415ee736b --- /dev/null +++ b/server/public/pluginapi/cluster/mutex_test.go @@ -0,0 +1,276 @@ +package cluster + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" +) + +func mustNewMutex(pluginAPI MutexPluginAPI, key string) *Mutex { + m, err := NewMutex(pluginAPI, key) + if err != nil { + panic(err) + } + + return m +} + +func TestMakeLockKey(t *testing.T) { + t.Run("fails when empty", func(t *testing.T) { + key, err := makeLockKey("") + assert.Error(t, err) + assert.Empty(t, key) + }) + + t.Run("not-empty", func(t *testing.T) { + testCases := map[string]string{ + "key": mutexPrefix + "key", + "other": mutexPrefix + "other", + } + + for key, expected := range testCases { + actual, err := makeLockKey(key) + require.NoError(t, err) + assert.Equal(t, expected, actual) + } + }) +} + +func lock(t *testing.T, m *Mutex) { + t.Helper() + + done := make(chan bool) + go func() { + t.Helper() + + defer close(done) + m.Lock() + }() + + select { + case <-time.After(1 * time.Second): + require.Fail(t, "failed to lock mutex within 1 second") + case <-done: + } +} + +func unlock(t *testing.T, m *Mutex, panics bool) { + t.Helper() + + done := make(chan bool) + go func() { + t.Helper() + + defer close(done) + if panics { + assert.Panics(t, m.Unlock) + } else { + assert.NotPanics(t, m.Unlock) + } + }() + + select { + case <-time.After(1 * time.Second): + require.Fail(t, "failed to unlock mutex within 1 second") + case <-done: + } +} + +func TestMutex(t *testing.T) { + t.Parallel() + + makeKey := model.NewId + + t.Run("successful lock/unlock cycle", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + m := mustNewMutex(mockPluginAPI, makeKey()) + lock(t, m) + unlock(t, m, false) + lock(t, m) + unlock(t, m, false) + }) + + t.Run("unlock when not locked", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + m := mustNewMutex(mockPluginAPI, makeKey()) + unlock(t, m, true) + }) + + t.Run("blocking lock", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + m := mustNewMutex(mockPluginAPI, makeKey()) + lock(t, m) + + done := make(chan bool) + go func() { + defer close(done) + m.Lock() + }() + + select { + case <-time.After(1 * time.Second): + case <-done: + require.Fail(t, "second goroutine should not have locked") + } + + unlock(t, m, false) + + select { + case <-time.After(pollWaitInterval * 2): + require.Fail(t, "second goroutine should have locked") + case <-done: + } + }) + + t.Run("failed lock", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + m := mustNewMutex(mockPluginAPI, makeKey()) + + mockPluginAPI.setFailing(true) + + done := make(chan bool) + go func() { + defer close(done) + m.Lock() + }() + + select { + case <-time.After(5 * time.Second): + case <-done: + require.Fail(t, "goroutine should not have locked") + } + + mockPluginAPI.setFailing(false) + + select { + case <-time.After(15 * time.Second): + require.Fail(t, "goroutine should have locked") + case <-done: + } + }) + + t.Run("failed unlock", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + key := makeKey() + m := mustNewMutex(mockPluginAPI, key) + lock(t, m) + + mockPluginAPI.setFailing(true) + + unlock(t, m, false) + + // Simulate expiry + mockPluginAPI.clear() + mockPluginAPI.setFailing(false) + + lock(t, m) + }) + + t.Run("discrete keys", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + m1 := mustNewMutex(mockPluginAPI, makeKey()) + lock(t, m1) + + m2 := mustNewMutex(mockPluginAPI, makeKey()) + lock(t, m2) + + m3 := mustNewMutex(mockPluginAPI, makeKey()) + lock(t, m3) + + unlock(t, m1, false) + unlock(t, m3, false) + + lock(t, m1) + + unlock(t, m2, false) + unlock(t, m1, false) + }) + + t.Run("with uncancelled context", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + key := makeKey() + m := mustNewMutex(mockPluginAPI, key) + + m.Lock() + + ctx := context.Background() + done := make(chan bool) + go func() { + defer close(done) + err := m.LockWithContext(ctx) + require.Nil(t, err) + }() + + select { + case <-time.After(ttl + pollWaitInterval*2): + case <-done: + require.Fail(t, "goroutine should not have locked") + } + + m.Unlock() + + select { + case <-time.After(pollWaitInterval * 2): + require.Fail(t, "goroutine should have locked after unlock") + case <-done: + } + }) + + t.Run("with canceled context", func(t *testing.T) { + t.Parallel() + + mockPluginAPI := newMockPluginAPI(t) + + m := mustNewMutex(mockPluginAPI, makeKey()) + + m.Lock() + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan bool) + go func() { + defer close(done) + err := m.LockWithContext(ctx) + require.NotNil(t, err) + }() + + select { + case <-time.After(ttl + pollWaitInterval*2): + case <-done: + require.Fail(t, "goroutine should not have locked") + } + + cancel() + + select { + case <-time.After(pollWaitInterval * 2): + require.Fail(t, "goroutine should have aborted after cancellation") + case <-done: + } + }) +} diff --git a/server/public/pluginapi/cluster/wait.go b/server/public/pluginapi/cluster/wait.go new file mode 100644 index 0000000000..bf62b4ac8d --- /dev/null +++ b/server/public/pluginapi/cluster/wait.go @@ -0,0 +1,43 @@ +package cluster + +import ( + "math/rand" + "time" +) + +const ( + // minWaitInterval is the minimum amount of time to wait between locking attempts + minWaitInterval = 1 * time.Second + + // maxWaitInterval is the maximum amount of time to wait between locking attempts + maxWaitInterval = 5 * time.Minute + + // pollWaitInterval is the usual time to wait between unsuccessful locking attempts + pollWaitInterval = 1 * time.Second + + // jitterWaitInterval is the amount of jitter to add when waiting to avoid thundering herds + jitterWaitInterval = minWaitInterval / 2 +) + +// nextWaitInterval determines how long to wait until the next lock retry. +func nextWaitInterval(lastWaitInterval time.Duration, err error) time.Duration { + nextWaitInterval := lastWaitInterval + + if nextWaitInterval <= 0 { + nextWaitInterval = minWaitInterval + } + + if err != nil { + nextWaitInterval *= 2 + if nextWaitInterval > maxWaitInterval { + nextWaitInterval = maxWaitInterval + } + } else { + nextWaitInterval = pollWaitInterval + } + + // Add some jitter to avoid unnecessary collision between competing plugin instances. + nextWaitInterval += time.Duration(rand.Int63n(int64(jitterWaitInterval)) - int64(jitterWaitInterval)/2) + + return nextWaitInterval +} diff --git a/server/public/pluginapi/cluster/wait_test.go b/server/public/pluginapi/cluster/wait_test.go new file mode 100644 index 0000000000..fb6f35ae92 --- /dev/null +++ b/server/public/pluginapi/cluster/wait_test.go @@ -0,0 +1,156 @@ +package cluster + +import ( + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestNextWaitInterval(t *testing.T) { + testCases := []struct { + Description string + lastWaitInterval time.Duration + err error + expectedRange [2]time.Duration + }{ + { + "0, no error", + 0, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "0, error", + 0, + errors.New("test"), + [2]time.Duration{ + 2*time.Second - jitterWaitInterval/2, + 2*time.Second + jitterWaitInterval/2, + }, + }, + { + "negative, no error", + -100 * time.Second, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "negative, error", + -100 * time.Second, + errors.New("test"), + [2]time.Duration{ + 2*time.Second - jitterWaitInterval/2, + 2*time.Second + jitterWaitInterval/2, + }, + }, + { + "1 second, no error", + 1 * time.Second, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "1 second, error", + 1 * time.Second, + errors.New("test"), + [2]time.Duration{ + 2*time.Second - jitterWaitInterval/2, + 2*time.Second + jitterWaitInterval/2, + }, + }, + { + "10 seconds, no error", + 10 * time.Second, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "10 second, error", + 10 * time.Second, + errors.New("test"), + [2]time.Duration{ + 20*time.Second - jitterWaitInterval/2, + 20*time.Second + jitterWaitInterval/2, + }, + }, + { + "4 minutes, no error", + 4 * time.Minute, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "4 minutes, error", + 4 * time.Minute, + errors.New("test"), + [2]time.Duration{ + 5*time.Minute - jitterWaitInterval/2, + 5*time.Minute + jitterWaitInterval/2, + }, + }, + { + "5 minutes, no error", + 5 * time.Minute, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "5 minutes, error", + 5 * time.Minute, + errors.New("test"), + [2]time.Duration{ + 5*time.Minute - jitterWaitInterval/2, + 5*time.Minute + jitterWaitInterval/2, + }, + }, + { + "10minutes, no error", + 10 * time.Minute, + nil, + [2]time.Duration{ + 1*time.Second - jitterWaitInterval/2, + 1*time.Second + jitterWaitInterval/2, + }, + }, + { + "10minutes, error", + 10 * time.Minute, + errors.New("test"), + [2]time.Duration{ + 5*time.Minute - jitterWaitInterval/2, + 5*time.Minute + jitterWaitInterval/2, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Description, func(t *testing.T) { + actualWaitInterval := nextWaitInterval( + testCase.lastWaitInterval, + testCase.err, + ) + assert.GreaterOrEqual(t, int64(actualWaitInterval), int64(testCase.expectedRange[0])) + assert.LessOrEqual(t, int64(actualWaitInterval), int64(testCase.expectedRange[1])) + }) + } +} diff --git a/server/public/pluginapi/cluster_test.go b/server/public/pluginapi/cluster_test.go new file mode 100644 index 0000000000..85348f2412 --- /dev/null +++ b/server/public/pluginapi/cluster_test.go @@ -0,0 +1,48 @@ +package pluginapi_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestPublishPluginClusterEvent(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("PublishPluginClusterEvent", + model.PluginClusterEvent{Id: "someID", Data: []byte("foo")}, + model.PluginClusterEventSendOptions{SendType: model.PluginClusterEventSendTypeReliable}, + ).Return(nil) + + err := client.Cluster.PublishPluginEvent( + model.PluginClusterEvent{Id: "someID", Data: []byte("foo")}, + model.PluginClusterEventSendOptions{SendType: model.PluginClusterEventSendTypeReliable}, + ) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("PublishPluginClusterEvent", + model.PluginClusterEvent{Id: "someID", Data: []byte("foo")}, + model.PluginClusterEventSendOptions{SendType: model.PluginClusterEventSendTypeReliable}, + ).Return(errors.New("someError")) + + err := client.Cluster.PublishPluginEvent( + model.PluginClusterEvent{Id: "someID", Data: []byte("foo")}, + model.PluginClusterEventSendOptions{SendType: model.PluginClusterEventSendTypeReliable}, + ) + require.Error(t, err) + }) +} diff --git a/server/public/pluginapi/configuration.go b/server/public/pluginapi/configuration.go new file mode 100644 index 0000000000..f28cfe55e8 --- /dev/null +++ b/server/public/pluginapi/configuration.go @@ -0,0 +1,55 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// ConfigurationService exposes methods to manipulate the server and plugin configuration. +type ConfigurationService struct { + api plugin.API +} + +// LoadPluginConfiguration loads the plugin's configuration. dest should be a pointer to a +// struct to which the configuration JSON can be unmarshalled. +// +// Minimum server version: 5.2 +func (c *ConfigurationService) LoadPluginConfiguration(dest interface{}) error { + // TODO: Isn't this method redundant given GetPluginConfig() and even GetConfig()? + return c.api.LoadPluginConfiguration(dest) +} + +// GetConfig fetches the currently persisted config. +// +// Minimum server version: 5.2 +func (c *ConfigurationService) GetConfig() *model.Config { + return c.api.GetConfig() +} + +// GetUnsanitizedConfig fetches the currently persisted config without removing secrets. +// +// Minimum server version: 5.16 +func (c *ConfigurationService) GetUnsanitizedConfig() *model.Config { + return c.api.GetUnsanitizedConfig() +} + +// SaveConfig sets the given config and persists the changes +// +// Minimum server version: 5.2 +func (c *ConfigurationService) SaveConfig(cfg *model.Config) error { + return normalizeAppErr(c.api.SaveConfig(cfg)) +} + +// GetPluginConfig fetches the currently persisted config of plugin +// +// Minimum server version: 5.6 +func (c *ConfigurationService) GetPluginConfig() map[string]interface{} { + return c.api.GetPluginConfig() +} + +// SavePluginConfig sets the given config for plugin and persists the changes +// +// Minimum server version: 5.6 +func (c *ConfigurationService) SavePluginConfig(cfg map[string]interface{}) error { + return normalizeAppErr(c.api.SavePluginConfig(cfg)) +} diff --git a/server/public/pluginapi/email.go b/server/public/pluginapi/email.go new file mode 100644 index 0000000000..80a4d97c6e --- /dev/null +++ b/server/public/pluginapi/email.go @@ -0,0 +1,17 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/plugin" +) + +// MailService exposes methods to send email. +type MailService struct { + api plugin.API +} + +// Send sends an email to a specific address. +// +// Minimum server version: 5.7 +func (m *MailService) Send(to, subject, htmlBody string) error { + return normalizeAppErr(m.api.SendMail(to, subject, htmlBody)) +} diff --git a/server/public/pluginapi/emoji.go b/server/public/pluginapi/emoji.go new file mode 100644 index 0000000000..ec21de80ae --- /dev/null +++ b/server/public/pluginapi/emoji.go @@ -0,0 +1,54 @@ +package pluginapi + +import ( + "bytes" + "io" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// EmojiService exposes methods to manipulate emojis. +type EmojiService struct { + api plugin.API +} + +// Get gets a custom emoji by id. +// +// Minimum server version: 5.6 +func (e *EmojiService) Get(id string) (*model.Emoji, error) { + emoji, appErr := e.api.GetEmoji(id) + + return emoji, normalizeAppErr(appErr) +} + +// GetByName gets a custom emoji by its name. +// +// Minimum server version: 5.6 +func (e *EmojiService) GetByName(name string) (*model.Emoji, error) { + emoji, appErr := e.api.GetEmojiByName(name) + + return emoji, normalizeAppErr(appErr) +} + +// GetImage gets a custom emoji's content and format by id. +// +// Minimum server version: 5.6 +func (e *EmojiService) GetImage(id string) (io.Reader, string, error) { + contentBytes, format, appErr := e.api.GetEmojiImage(id) + if appErr != nil { + return nil, "", normalizeAppErr(appErr) + } + + return bytes.NewReader(contentBytes), format, nil +} + +// List retrieves a list of custom emojis. +// sortBy parameter can be: "name". +// +// Minimum server version: 5.6 +func (e *EmojiService) List(sortBy string, page, count int) ([]*model.Emoji, error) { + emojis, appErr := e.api.GetEmojiList(sortBy, page, count) + + return emojis, normalizeAppErr(appErr) +} diff --git a/server/public/pluginapi/emoji_test.go b/server/public/pluginapi/emoji_test.go new file mode 100644 index 0000000000..8bbe8e37d9 --- /dev/null +++ b/server/public/pluginapi/emoji_test.go @@ -0,0 +1,131 @@ +package pluginapi_test + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestGetEmoji(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetEmoji", "1").Return(&model.Emoji{Id: "2"}, nil) + + emoji, err := client.Emoji.Get("1") + require.NoError(t, err) + require.Equal(t, "2", emoji.Id) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetEmoji", "1").Return(nil, appErr) + + emoji, err := client.Emoji.Get("1") + require.Equal(t, appErr, err) + require.Zero(t, emoji) + }) +} + +func TestGetEmojiByName(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetEmojiByName", "1").Return(&model.Emoji{Id: "2"}, nil) + + emoji, err := client.Emoji.GetByName("1") + require.NoError(t, err) + require.Equal(t, "2", emoji.Id) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetEmojiByName", "1").Return(nil, appErr) + + emoji, err := client.Emoji.GetByName("1") + require.Equal(t, appErr, err) + require.Zero(t, emoji) + }) +} + +func TestGetEmojiImage(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetEmojiImage", "1").Return([]byte{1}, "jpg", nil) + + content, format, err := client.Emoji.GetImage("1") + require.NoError(t, err) + contentBytes, err := io.ReadAll(content) + require.NoError(t, err) + require.Equal(t, []byte{1}, contentBytes) + require.Equal(t, "jpg", format) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetEmojiImage", "1").Return(nil, "", appErr) + + content, format, err := client.Emoji.GetImage("1") + require.Equal(t, appErr, err) + require.Zero(t, content) + require.Zero(t, format) + }) +} + +func TestListEmojis(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetEmojiList", "1", 2, 3).Return([]*model.Emoji{ + {Id: "4"}, + }, nil) + + emojis, err := client.Emoji.List("1", 2, 3) + require.NoError(t, err) + require.Len(t, emojis, 1) + require.Equal(t, "4", emojis[0].Id) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetEmojiList", "1", 2, 3).Return(nil, appErr) + + emojis, err := client.Emoji.List("1", 2, 3) + require.Equal(t, appErr, err) + require.Zero(t, emojis) + }) +} diff --git a/server/public/pluginapi/error.go b/server/public/pluginapi/error.go new file mode 100644 index 0000000000..32765915ef --- /dev/null +++ b/server/public/pluginapi/error.go @@ -0,0 +1,39 @@ +package pluginapi + +import ( + "net/http" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" +) + +// ErrNotFound is returned by the plugin API when an object is not found. +var ErrNotFound = errors.New("not found") + +// normalizeAppErr returns a truly nil error if appErr is nil as well as normalizing a class +// of non-nil AppErrors to simplify use within plugins. +// +// This doesn't happen automatically when a *model.AppError is cast to an error, since the +// resulting error interface has a concrete type with a nil value. This leads to the seemingly +// impossible: +// +// var err error +// err = func() *model.AppError { return nil }() +// if err != nil { +// panic("err != nil, which surprises most") +// } +// +// Fix this problem for all plugin authors by normalizing to special case the handling of a nil +// *model.AppError. See https://golang.org/doc/faq#nil_error for more details. +func normalizeAppErr(appErr *model.AppError) error { + if appErr == nil { + return nil + } + + if appErr.StatusCode == http.StatusNotFound { + return ErrNotFound + } + + return appErr +} diff --git a/server/public/pluginapi/example_client_test.go b/server/public/pluginapi/example_client_test.go new file mode 100644 index 0000000000..c254c370bd --- /dev/null +++ b/server/public/pluginapi/example_client_test.go @@ -0,0 +1,21 @@ +package pluginapi_test + +import ( + "github.com/mattermost/mattermost/server/public/pluginapi" + + "github.com/mattermost/mattermost/server/public/plugin" +) + +type Plugin struct { + plugin.MattermostPlugin + client *pluginapi.Client +} + +func (p *Plugin) OnActivate() error { + p.client = pluginapi.NewClient(p.API, p.Driver) + + return nil +} + +func Example() { +} diff --git a/server/public/pluginapi/experimental/bot/bot.go b/server/public/pluginapi/experimental/bot/bot.go new file mode 100644 index 0000000000..76634866e5 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/bot.go @@ -0,0 +1,53 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package bot + +import ( + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +type Bot interface { + Ensure(stored *model.Bot, iconPath string) error + MattermostUserID() string + String() string +} + +type bot struct { + botService pluginapi.BotService + mattermostUserID string + displayName string +} + +func New(botService pluginapi.BotService) Bot { + newBot := &bot{ + botService: botService, + } + return newBot +} + +func (bot *bot) Ensure(stored *model.Bot, iconPath string) error { + if bot.mattermostUserID != "" { + // Already done + return nil + } + + botUserID, err := bot.botService.EnsureBot(stored, pluginapi.ProfileImagePath(iconPath)) + if err != nil { + return errors.Wrap(err, "failed to ensure bot account") + } + bot.mattermostUserID = botUserID + bot.displayName = stored.DisplayName + return nil +} + +func (bot *bot) MattermostUserID() string { + return bot.mattermostUserID +} + +func (bot *bot) String() string { + return bot.displayName +} diff --git a/server/public/pluginapi/experimental/bot/logger/admincclogger/admincc_logger.go b/server/public/pluginapi/experimental/bot/logger/admincclogger/admincc_logger.go new file mode 100644 index 0000000000..b3c549e75e --- /dev/null +++ b/server/public/pluginapi/experimental/bot/logger/admincclogger/admincc_logger.go @@ -0,0 +1,97 @@ +package admincclogger + +import ( + "fmt" + + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" +) + +type adminCCLogger struct { + logger.Logger + dmer poster.DMer + logLevel logger.LogLevel + includeContext bool + userIDs []string +} + +/* +New promotes the provided logger into a admin cc logger, sending direct messages to all the admin +ids provided through the dmer provided, about all events below the logLevel. If logVerbose is set, +it will also send the context. + +- l Logger: A logger to promote. + +- dmer DMer: A DMer to send the messages to the admins. + +- logLevel: The highest type of message to be stored in telemetry. + +- includeContext: Whether the log context should be messaged to the admins. + +- userIDs: The user IDs of the admins. +*/ +func New(l logger.Logger, dmer poster.DMer, logLevel logger.LogLevel, includeContext bool, userIDs ...string) logger.Logger { + return &adminCCLogger{ + Logger: l, + dmer: dmer, + logLevel: logLevel, + includeContext: includeContext, + userIDs: userIDs, + } +} + +// NewFromAPI creates a adminCCLogger directly from a LogAPI instead of passing a logger. +func NewFromAPI(api common.LogAPI, dmer poster.DMer, logLevel logger.LogLevel, includeContext bool, userIDs ...string) logger.Logger { + return New(logger.New(api), dmer, logLevel, includeContext, userIDs...) +} + +func (l *adminCCLogger) Debugf(format string, args ...interface{}) { + l.Logger.Debugf(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 4 { + l.logToAdmins("DEBUG", message) + } +} + +func (l *adminCCLogger) Errorf(format string, args ...interface{}) { + l.Logger.Errorf(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 1 { + l.logToAdmins("ERROR", message) + } +} + +func (l *adminCCLogger) Infof(format string, args ...interface{}) { + l.Logger.Infof(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 3 { + l.logToAdmins("INFO", message) + } +} + +func (l *adminCCLogger) Warnf(format string, args ...interface{}) { + l.Logger.Warnf(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 2 { + l.logToAdmins("WARN", message) + } +} + +func (l *adminCCLogger) logToAdmins(level, message string) { + context := l.Context() + if l.includeContext && len(context) > 0 { + message += "\n" + common.JSONBlock(context) + } + _ = l.dmAdmins("(log " + level + ") " + message) +} + +func (l *adminCCLogger) dmAdmins(format string, args ...interface{}) error { + for _, id := range l.userIDs { + _, err := l.dmer.DM(id, format, args) + if err != nil { + return err + } + } + return nil +} diff --git a/server/public/pluginapi/experimental/bot/logger/default_logger.go b/server/public/pluginapi/experimental/bot/logger/default_logger.go new file mode 100644 index 0000000000..e51adf8069 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/logger/default_logger.go @@ -0,0 +1,82 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package logger + +import ( + "fmt" + "time" + + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" +) + +type defaultLogger struct { + logContext LogContext + logAPI common.LogAPI +} + +/* +New creates a new logger. + +- api: LogAPI implementation +*/ +func New(api common.LogAPI) Logger { + l := &defaultLogger{ + logAPI: api, + } + return l +} + +func (l *defaultLogger) With(logContext LogContext) Logger { + newLogger := *l + if len(newLogger.logContext) == 0 { + newLogger.logContext = map[string]interface{}{} + } + for k, v := range logContext { + newLogger.logContext[k] = v + } + return &newLogger +} + +func (l *defaultLogger) WithError(err error) Logger { + newLogger := *l + if len(newLogger.logContext) == 0 { + newLogger.logContext = map[string]interface{}{} + } + newLogger.logContext[ErrorKey] = err.Error() + return &newLogger +} + +func (l *defaultLogger) Context() LogContext { + return l.logContext +} + +func (l *defaultLogger) Timed() Logger { + return l.With(LogContext{ + timed: time.Now(), + }) +} + +func (l *defaultLogger) Debugf(format string, args ...interface{}) { + measure(l.logContext) + message := fmt.Sprintf(format, args...) + l.logAPI.LogDebug(message, toKeyValuePairs(l.logContext)...) +} + +func (l *defaultLogger) Errorf(format string, args ...interface{}) { + measure(l.logContext) + message := fmt.Sprintf(format, args...) + l.logAPI.LogError(message, toKeyValuePairs(l.logContext)...) +} + +func (l *defaultLogger) Infof(format string, args ...interface{}) { + measure(l.logContext) + message := fmt.Sprintf(format, args...) + l.logAPI.LogInfo(message, toKeyValuePairs(l.logContext)...) +} + +func (l *defaultLogger) Warnf(format string, args ...interface{}) { + measure(l.logContext) + message := fmt.Sprintf(format, args...) + l.logAPI.LogWarn(message, toKeyValuePairs(l.logContext)...) +} diff --git a/server/public/pluginapi/experimental/bot/logger/logger.go b/server/public/pluginapi/experimental/bot/logger/logger.go new file mode 100644 index 0000000000..71e3924c1b --- /dev/null +++ b/server/public/pluginapi/experimental/bot/logger/logger.go @@ -0,0 +1,78 @@ +package logger + +import "time" + +const ( + timed = "__since" + elapsed = "Elapsed" + + ErrorKey = "error" +) + +// LogLevel defines the level of log messages +type LogLevel string + +const ( + // LogLevelDebug denotes debug messages + LogLevelDebug = "debug" + // LogLevelInfo denotes info messages + LogLevelInfo = "info" + // LogLevelWarn denotes warn messages + LogLevelWarn = "warn" + // LogLevelError denotes error messages + LogLevelError = "error" +) + +// LogContext defines the context for the logs. +type LogContext map[string]interface{} + +// Logger defines an object able to log messages. +type Logger interface { + // With adds a logContext to the logger. + With(LogContext) Logger + // WithError adds an Error to the logger. + WithError(error) Logger + // Context returns the current context + Context() LogContext + // Timed add a timed log context. + Timed() Logger + // Debugf logs a formatted string as a debug message. + Debugf(format string, args ...interface{}) + // Errorf logs a formatted string as an error message. + Errorf(format string, args ...interface{}) + // Infof logs a formatted string as an info message. + Infof(format string, args ...interface{}) + // Warnf logs a formatted string as an warning message. + Warnf(format string, args ...interface{}) +} + +func measure(lc LogContext) { + if lc[timed] == nil { + return + } + started := lc[timed].(time.Time) + lc[elapsed] = time.Since(started).String() + delete(lc, timed) +} + +// Level assigns an integer to the LogLevel string +func Level(l LogLevel) int { + switch l { + case LogLevelDebug: + return 4 + case LogLevelInfo: + return 3 + case LogLevelWarn: + return 2 + case LogLevelError: + return 1 + } + return 0 +} + +func toKeyValuePairs(in map[string]interface{}) (out []interface{}) { + for k, v := range in { + out = append(out, k, v) + } + return out +} diff --git a/server/public/pluginapi/experimental/bot/logger/nil_logger.go b/server/public/pluginapi/experimental/bot/logger/nil_logger.go new file mode 100644 index 0000000000..45beac83c2 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/logger/nil_logger.go @@ -0,0 +1,17 @@ +package logger + +type nilLogger struct{} + +// NewNilLogger returns a logger that performs no action. +func NewNilLogger() Logger { + return &nilLogger{} +} + +func (l *nilLogger) With(LogContext) Logger { return l } +func (l *nilLogger) WithError(error) Logger { return l } +func (l *nilLogger) Context() LogContext { return nil } +func (l *nilLogger) Timed() Logger { return l } +func (l *nilLogger) Debugf(string, ...interface{}) {} +func (l *nilLogger) Errorf(string, ...interface{}) {} +func (l *nilLogger) Infof(string, ...interface{}) {} +func (l *nilLogger) Warnf(string, ...interface{}) {} diff --git a/server/public/pluginapi/experimental/bot/logger/telemetrylogger/telemetry_logger.go b/server/public/pluginapi/experimental/bot/logger/telemetrylogger/telemetry_logger.go new file mode 100644 index 0000000000..8f6c802c8b --- /dev/null +++ b/server/public/pluginapi/experimental/bot/logger/telemetrylogger/telemetry_logger.go @@ -0,0 +1,80 @@ +package telemetrylogger + +import ( + "fmt" + + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/telemetry" +) + +type telemetryLogger struct { + logger.Logger + logLevel logger.LogLevel + tracker telemetry.Tracker +} + +/* +New promotes the provided logger into a telemetry logger, storing all events below the logLevel +through the tracker. + +- l Logger: A logger to promote. + +- logLevel: The highest type of message to be stored in telemetry. + +- tracker: The telemetry tracker to store the messages. +*/ +func New(l logger.Logger, logLevel logger.LogLevel, tracker telemetry.Tracker) logger.Logger { + return &telemetryLogger{ + Logger: l, + logLevel: logLevel, + tracker: tracker, + } +} + +// NewFromAPI creates a telemetryLogger directly from a LogAPI instead of passing a logger. +func NewFromAPI(api common.LogAPI, logLevel logger.LogLevel, tracker telemetry.Tracker) logger.Logger { + return New(logger.New(api), logLevel, tracker) +} + +func (l *telemetryLogger) Debugf(format string, args ...interface{}) { + l.Logger.Debugf(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 4 { + l.logToTelemetry("DEBUG", message) + } +} + +func (l *telemetryLogger) Errorf(format string, args ...interface{}) { + l.Logger.Errorf(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 1 { + l.logToTelemetry("ERROR", message) + } +} + +func (l *telemetryLogger) Infof(format string, args ...interface{}) { + l.Logger.Infof(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 3 { + l.logToTelemetry("INFO", message) + } +} + +func (l *telemetryLogger) Warnf(format string, args ...interface{}) { + l.Logger.Warnf(format, args...) + message := fmt.Sprintf(format, args...) + if logger.Level(l.logLevel) >= 2 { + l.logToTelemetry("WARN", message) + } +} + +func (l *telemetryLogger) logToTelemetry(level, message string) { + properties := map[string]interface{}{} + properties["message"] = message + for k, v := range l.Context() { + properties["context_"+k] = fmt.Sprintf("%v", v) + } + + _ = l.tracker.TrackEvent("logger_"+level, properties) +} diff --git a/server/public/pluginapi/experimental/bot/logger/test_logger.go b/server/public/pluginapi/experimental/bot/logger/test_logger.go new file mode 100644 index 0000000000..afd85f12ab --- /dev/null +++ b/server/public/pluginapi/experimental/bot/logger/test_logger.go @@ -0,0 +1,61 @@ +package logger + +import ( + "fmt" + "testing" + "time" +) + +type testLogger struct { + testing.TB + logContext LogContext +} + +// NewTestLogger creates a logger for testing purposes. +func NewTestLogger() Logger { + return &testLogger{} +} + +func (l *testLogger) With(logContext LogContext) Logger { + newl := *l + if len(newl.logContext) == 0 { + newl.logContext = map[string]interface{}{} + } + for k, v := range logContext { + newl.logContext[k] = v + } + return &newl +} + +func (l *testLogger) WithError(err error) Logger { + newl := *l + if len(newl.logContext) == 0 { + newl.logContext = map[string]interface{}{} + } + newl.logContext[ErrorKey] = err.Error() + return &newl +} + +func (l *testLogger) Context() LogContext { + return l.logContext +} + +func (l *testLogger) Timed() Logger { + return l.With(LogContext{ + timed: time.Now(), + }) +} + +func (l *testLogger) logf(prefix, format string, args ...interface{}) { + out := fmt.Sprintf(prefix+": "+format, args...) + if len(l.logContext) > 0 { + measure(l.logContext) + out += fmt.Sprintf(" -- %+v", l.logContext) + } + l.TB.Logf(out) +} + +func (l *testLogger) Debugf(format string, args ...interface{}) { l.logf("DEBUG", format, args...) } +func (l *testLogger) Errorf(format string, args ...interface{}) { l.logf("ERROR", format, args...) } +func (l *testLogger) Infof(format string, args ...interface{}) { l.logf("INFO", format, args...) } +func (l *testLogger) Warnf(format string, args ...interface{}) { l.logf("WARN", format, args...) } diff --git a/server/public/pluginapi/experimental/bot/mocks/mock_bot.go b/server/public/pluginapi/experimental/bot/mocks/mock_bot.go new file mode 100644 index 0000000000..9ff7b4117c --- /dev/null +++ b/server/public/pluginapi/experimental/bot/mocks/mock_bot.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot (interfaces: Bot) + +// Package mock_bot is a generated GoMock package. +package mock_bot + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + model "github.com/mattermost/mattermost/server/public/model" +) + +// MockBot is a mock of Bot interface. +type MockBot struct { + ctrl *gomock.Controller + recorder *MockBotMockRecorder +} + +// MockBotMockRecorder is the mock recorder for MockBot. +type MockBotMockRecorder struct { + mock *MockBot +} + +// NewMockBot creates a new mock instance. +func NewMockBot(ctrl *gomock.Controller) *MockBot { + mock := &MockBot{ctrl: ctrl} + mock.recorder = &MockBotMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBot) EXPECT() *MockBotMockRecorder { + return m.recorder +} + +// Ensure mocks base method. +func (m *MockBot) Ensure(arg0 *model.Bot, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ensure", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Ensure indicates an expected call of Ensure. +func (mr *MockBotMockRecorder) Ensure(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ensure", reflect.TypeOf((*MockBot)(nil).Ensure), arg0, arg1) +} + +// MattermostUserID mocks base method. +func (m *MockBot) MattermostUserID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MattermostUserID") + ret0, _ := ret[0].(string) + return ret0 +} + +// MattermostUserID indicates an expected call of MattermostUserID. +func (mr *MockBotMockRecorder) MattermostUserID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MattermostUserID", reflect.TypeOf((*MockBot)(nil).MattermostUserID)) +} + +// String mocks base method. +func (m *MockBot) String() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "String") + ret0, _ := ret[0].(string) + return ret0 +} + +// String indicates an expected call of String. +func (mr *MockBotMockRecorder) String() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockBot)(nil).String)) +} diff --git a/server/public/pluginapi/experimental/bot/mocks/mock_logger.go b/server/public/pluginapi/experimental/bot/mocks/mock_logger.go new file mode 100644 index 0000000000..f096bdae3d --- /dev/null +++ b/server/public/pluginapi/experimental/bot/mocks/mock_logger.go @@ -0,0 +1,159 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger (interfaces: Logger) + +// Package mock_bot is a generated GoMock package. +package mock_bot + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + logger "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Context mocks base method. +func (m *MockLogger) Context() logger.LogContext { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(logger.LogContext) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockLoggerMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockLogger)(nil).Context)) +} + +// Debugf mocks base method. +func (m *MockLogger) Debugf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf. +func (mr *MockLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Errorf mocks base method. +func (m *MockLogger) Errorf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Errorf", varargs...) +} + +// Errorf indicates an expected call of Errorf. +func (mr *MockLoggerMockRecorder) Errorf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...) +} + +// Infof mocks base method. +func (m *MockLogger) Infof(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Infof", varargs...) +} + +// Infof indicates an expected call of Infof. +func (mr *MockLoggerMockRecorder) Infof(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...) +} + +// Timed mocks base method. +func (m *MockLogger) Timed() logger.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Timed") + ret0, _ := ret[0].(logger.Logger) + return ret0 +} + +// Timed indicates an expected call of Timed. +func (mr *MockLoggerMockRecorder) Timed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timed", reflect.TypeOf((*MockLogger)(nil).Timed)) +} + +// Warnf mocks base method. +func (m *MockLogger) Warnf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warnf", varargs...) +} + +// Warnf indicates an expected call of Warnf. +func (mr *MockLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...) +} + +// With mocks base method. +func (m *MockLogger) With(arg0 logger.LogContext) logger.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "With", arg0) + ret0, _ := ret[0].(logger.Logger) + return ret0 +} + +// With indicates an expected call of With. +func (mr *MockLoggerMockRecorder) With(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockLogger)(nil).With), arg0) +} + +// WithError mocks base method. +func (m *MockLogger) WithError(arg0 error) logger.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithError", arg0) + ret0, _ := ret[0].(logger.Logger) + return ret0 +} + +// WithError indicates an expected call of WithError. +func (mr *MockLoggerMockRecorder) WithError(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithError", reflect.TypeOf((*MockLogger)(nil).WithError), arg0) +} diff --git a/server/public/pluginapi/experimental/bot/mocks/mock_poster.go b/server/public/pluginapi/experimental/bot/mocks/mock_poster.go new file mode 100644 index 0000000000..c826fb5ad9 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/mocks/mock_poster.go @@ -0,0 +1,151 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster (interfaces: Poster) + +// Package mock_bot is a generated GoMock package. +package mock_bot + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + model "github.com/mattermost/mattermost/server/public/model" +) + +// MockPoster is a mock of Poster interface. +type MockPoster struct { + ctrl *gomock.Controller + recorder *MockPosterMockRecorder +} + +// MockPosterMockRecorder is the mock recorder for MockPoster. +type MockPosterMockRecorder struct { + mock *MockPoster +} + +// NewMockPoster creates a new mock instance. +func NewMockPoster(ctrl *gomock.Controller) *MockPoster { + mock := &MockPoster{ctrl: ctrl} + mock.recorder = &MockPosterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPoster) EXPECT() *MockPosterMockRecorder { + return m.recorder +} + +// DM mocks base method. +func (m *MockPoster) DM(arg0, arg1 string, arg2 ...interface{}) (string, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DM", varargs...) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DM indicates an expected call of DM. +func (mr *MockPosterMockRecorder) DM(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DM", reflect.TypeOf((*MockPoster)(nil).DM), varargs...) +} + +// DMWithAttachments mocks base method. +func (m *MockPoster) DMWithAttachments(arg0 string, arg1 ...*model.SlackAttachment) (string, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DMWithAttachments", varargs...) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DMWithAttachments indicates an expected call of DMWithAttachments. +func (mr *MockPosterMockRecorder) DMWithAttachments(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DMWithAttachments", reflect.TypeOf((*MockPoster)(nil).DMWithAttachments), varargs...) +} + +// DeletePost mocks base method. +func (m *MockPoster) DeletePost(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePost", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePost indicates an expected call of DeletePost. +func (mr *MockPosterMockRecorder) DeletePost(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePost", reflect.TypeOf((*MockPoster)(nil).DeletePost), arg0) +} + +// Ephemeral mocks base method. +func (m *MockPoster) Ephemeral(arg0, arg1, arg2 string, arg3 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Ephemeral", varargs...) +} + +// Ephemeral indicates an expected call of Ephemeral. +func (mr *MockPosterMockRecorder) Ephemeral(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ephemeral", reflect.TypeOf((*MockPoster)(nil).Ephemeral), varargs...) +} + +// UpdatePost mocks base method. +func (m *MockPoster) UpdatePost(arg0 *model.Post) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePost", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdatePost indicates an expected call of UpdatePost. +func (mr *MockPosterMockRecorder) UpdatePost(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePost", reflect.TypeOf((*MockPoster)(nil).UpdatePost), arg0) +} + +// UpdatePostByID mocks base method. +func (m *MockPoster) UpdatePostByID(arg0, arg1 string, arg2 ...interface{}) error { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdatePostByID", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdatePostByID indicates an expected call of UpdatePostByID. +func (mr *MockPosterMockRecorder) UpdatePostByID(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePostByID", reflect.TypeOf((*MockPoster)(nil).UpdatePostByID), varargs...) +} + +// UpdatePosterID mocks base method. +func (m *MockPoster) UpdatePosterID(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatePosterID", arg0) +} + +// UpdatePosterID indicates an expected call of UpdatePosterID. +func (mr *MockPosterMockRecorder) UpdatePosterID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePosterID", reflect.TypeOf((*MockPoster)(nil).UpdatePosterID), arg0) +} diff --git a/server/public/pluginapi/experimental/bot/poster/default_poster.go b/server/public/pluginapi/experimental/bot/poster/default_poster.go new file mode 100644 index 0000000000..16b9f45c2b --- /dev/null +++ b/server/public/pluginapi/experimental/bot/poster/default_poster.go @@ -0,0 +1,76 @@ +package poster + +import ( + "fmt" + + "github.com/mattermost/mattermost/server/public/model" +) + +type defaultPoster struct { + postAPI PostAPI + id string +} + +// NewPoster creates a new default poster +func NewPoster(postAPI PostAPI, id string) Poster { + return &defaultPoster{ + postAPI: postAPI, + id: id, + } +} + +// DM posts a simple Direct Message to the specified user +func (p *defaultPoster) DM(mattermostUserID, format string, args ...interface{}) (string, error) { + post := &model.Post{ + Message: fmt.Sprintf(format, args...), + } + err := p.postAPI.DM(p.id, mattermostUserID, post) + if err != nil { + return "", err + } + return post.Id, nil +} + +// DMWithAttachments posts a Direct Message that contains Slack attachments. +// Often used to include post actions. +func (p *defaultPoster) DMWithAttachments(mattermostUserID string, attachments ...*model.SlackAttachment) (string, error) { + post := model.Post{} + model.ParseSlackAttachment(&post, attachments) + err := p.postAPI.DM(p.id, mattermostUserID, &post) + if err != nil { + return "", err + } + return post.Id, nil +} + +// Ephemeral sends an ephemeral message to a user +func (p *defaultPoster) Ephemeral(userID, channelID, format string, args ...interface{}) { + post := &model.Post{ + UserId: p.id, + ChannelId: channelID, + Message: fmt.Sprintf(format, args...), + } + p.postAPI.SendEphemeralPost(userID, post) +} + +func (p *defaultPoster) UpdatePostByID(postID, format string, args ...interface{}) error { + post, err := p.postAPI.GetPost(postID) + if err != nil { + return err + } + + post.Message = fmt.Sprintf(format, args...) + return p.UpdatePost(post) +} + +func (p *defaultPoster) DeletePost(postID string) error { + return p.postAPI.DeletePost(postID) +} + +func (p *defaultPoster) UpdatePost(post *model.Post) error { + return p.postAPI.UpdatePost(post) +} + +func (p *defaultPoster) UpdatePosterID(id string) { + p.id = id +} diff --git a/server/public/pluginapi/experimental/bot/poster/default_poster_test.go b/server/public/pluginapi/experimental/bot/poster/default_poster_test.go new file mode 100644 index 0000000000..2770b13f88 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/poster/default_poster_test.go @@ -0,0 +1,414 @@ +package poster + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster/mock_import" +) + +const ( + botID = "test-bot-user" + userID = "test-user-1" + dmChannelID = "dm-channel-id" +) + +func TestInterface(t *testing.T) { + t.Run("Plugin API satisfy the interface", func(t *testing.T) { + api := &plugintest.API{} + driver := &plugintest.Driver{} + client := pluginapi.NewClient(api, driver) + _ = NewPoster(&client.Post, botID) + }) +} + +func TestDM(t *testing.T) { + format := "test format, string: %s int: %d value: %v" + args := []interface{}{"some string", 5, 8.423} + expectedMessage := "test format, string: some string int: 5 value: 8.423" + + expectedPostID := "expected-post-id" + + post := &model.Post{ + Message: expectedMessage, + } + + postWithID := model.Post{ + Id: expectedPostID, + UserId: botID, + ChannelId: dmChannelID, + Message: expectedMessage, + } + + mockError := errors.New("mock error") + + t.Run("DM Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + //nolint:govet //copy lock, but only used in tests + postAPI. + EXPECT(). + DM(botID, userID, post). + SetArg(2, postWithID). + Return(nil). + Times(1) + + postID, err := poster.DM(userID, format, args...) + assert.Equal(t, expectedPostID, postID) + assert.NoError(t, err) + }) + + t.Run("DM error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + DM(botID, userID, post). + Return(mockError). + Times(1) + + _, err := poster.DM(userID, format, args...) + assert.Error(t, err) + }) +} + +func TestDMWithAttachments(t *testing.T) { + expectedPostID := "expected-post-id" + + attachments := []*model.SlackAttachment{ + {}, + {}, + } + + post := &model.Post{} + + model.ParseSlackAttachment(post, attachments) + + postWithID := model.Post{ + Id: expectedPostID, + UserId: botID, + ChannelId: dmChannelID, + Type: model.PostTypeSlackAttachment, + Props: model.StringInterface{ + "attachments": attachments, + }, + } + + mockError := errors.New("mock error") + t.Run("DM Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + //nolint:govet //copy lock, but only used in tests + postAPI. + EXPECT(). + DM(botID, userID, post). + SetArg(2, postWithID). + Return(nil). + Times(1) + + postID, err := poster.DMWithAttachments(userID, attachments...) + assert.Equal(t, expectedPostID, postID) + assert.NoError(t, err) + }) + + t.Run("DM error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + DM(botID, userID, post). + Return(mockError). + Times(1) + + _, err := poster.DMWithAttachments(userID, attachments...) + assert.Error(t, err) + }) +} + +func TestEphemeral(t *testing.T) { + format := "test format, string: %s int: %d value: %v" + args := []interface{}{"some string", 5, 8.423} + expectedMessage := "test format, string: some string int: 5 value: 8.423" + + channelID := "some-channel" + + post := &model.Post{ + UserId: botID, + ChannelId: channelID, + Message: expectedMessage, + } + + expectedPostID := "some-post-ID" + + postWithID := model.Post{ + Id: expectedPostID, + UserId: botID, + ChannelId: channelID, + Message: expectedMessage, + } + + t.Run("Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + //nolint:govet //copy lock, but only used in tests + postAPI. + EXPECT(). + SendEphemeralPost(userID, post). + SetArg(1, postWithID). + Times(1) + + poster.Ephemeral(userID, channelID, format, args...) + }) +} + +func TestUpdatePostByID(t *testing.T) { + format := "test format, string: %s int: %d value: %v" + args := []interface{}{"some string", 5, 8.423} + expectedMessage := "test format, string: some string int: 5 value: 8.423" + + postID := "some-post-id" + originalPost := &model.Post{ + Id: postID, + Message: "some message", + } + + updatedPost := &model.Post{ + Id: postID, + Message: expectedMessage, + } + + mockError := errors.New("mock error") + + t.Run("Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + GetPost(postID). + Return(originalPost, nil). + Times(1) + + postAPI. + EXPECT(). + UpdatePost(updatedPost). + Return(nil). + Times(1) + + err := poster.UpdatePostByID(postID, format, args...) + assert.NoError(t, err) + }) + + t.Run("Error fetching", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + GetPost(postID). + Return(nil, mockError). + Times(1) + + err := poster.UpdatePostByID(postID, format, args...) + assert.Error(t, err) + }) + + t.Run("Error updating", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + GetPost(postID). + Return(originalPost, nil). + Times(1) + + postAPI. + EXPECT(). + UpdatePost(updatedPost). + Return(mockError). + Times(1) + + err := poster.UpdatePostByID(postID, format, args...) + assert.Error(t, err) + }) +} + +func TestDeletePost(t *testing.T) { + postID := "some-post-id" + + mockError := errors.New("mock channel error") + t.Run("Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + DeletePost(postID). + Return(nil). + Times(1) + + err := poster.DeletePost(postID) + assert.NoError(t, err) + }) + + t.Run("Error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + DeletePost(postID). + Return(mockError). + Times(1) + + err := poster.DeletePost(postID) + assert.Error(t, err) + }) +} + +func TestUpdatePost(t *testing.T) { + post := &model.Post{ + Id: "some-post-id", + Message: "some message", + } + + mockError := errors.New("mock channel error") + t.Run("Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + UpdatePost(post). + Return(nil). + Times(1) + + err := poster.UpdatePost(post) + assert.NoError(t, err) + }) + + t.Run("Error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + postAPI. + EXPECT(). + UpdatePost(post). + Return(mockError). + Times(1) + + err := poster.UpdatePost(post) + assert.Error(t, err) + }) +} + +func TestUpdatePosterID(t *testing.T) { + format := "test format, string: %s int: %d value: %v" + args := []interface{}{"some string", 5, 8.423} + expectedMessage := "test format, string: some string int: 5 value: 8.423" + + expectedPostID := "expected-post-id" + + post := &model.Post{ + Message: expectedMessage, + } + + postWithID := model.Post{ + Id: expectedPostID, + UserId: botID, + ChannelId: dmChannelID, + Message: expectedMessage, + } + + newBotID := "new-bot-id" + + t.Run("Success", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + postAPI := mock_import.NewMockPostAPI(ctrl) + + poster := NewPoster(postAPI, botID) + + //nolint:govet //copy lock, but only used in tests + postAPI. + EXPECT(). + DM(botID, userID, post). + SetArg(2, postWithID). + Return(nil). + Times(1) + + _, _ = poster.DM(userID, format, args...) + poster.UpdatePosterID(newBotID) + + //nolint:govet //copy lock, but only used in tests + postAPI. + EXPECT(). + DM(newBotID, userID, post). + SetArg(2, postWithID). + Return(nil). + Times(1) + + _, _ = poster.DM(userID, format, args...) + }) +} diff --git a/server/public/pluginapi/experimental/bot/poster/import.go b/server/public/pluginapi/experimental/bot/poster/import.go new file mode 100644 index 0000000000..f3c532d970 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/poster/import.go @@ -0,0 +1,14 @@ +package poster + +import ( + "github.com/mattermost/mattermost/server/public/model" +) + +// PostAPI defines the portion of the Post Service used by the poster +type PostAPI interface { + DM(senderUserID, receiverUserID string, post *model.Post) error + GetPost(postID string) (*model.Post, error) + UpdatePost(post *model.Post) error + DeletePost(postID string) error + SendEphemeralPost(userID string, post *model.Post) +} diff --git a/server/public/pluginapi/experimental/bot/poster/mock_import/mock_postapi.go b/server/public/pluginapi/experimental/bot/poster/mock_import/mock_postapi.go new file mode 100644 index 0000000000..7fe3eb21c6 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/poster/mock_import/mock_postapi.go @@ -0,0 +1,104 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster (interfaces: PostAPI) + +// Package mock_import is a generated GoMock package. +package mock_import + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + model "github.com/mattermost/mattermost/server/public/model" +) + +// MockPostAPI is a mock of PostAPI interface. +type MockPostAPI struct { + ctrl *gomock.Controller + recorder *MockPostAPIMockRecorder +} + +// MockPostAPIMockRecorder is the mock recorder for MockPostAPI. +type MockPostAPIMockRecorder struct { + mock *MockPostAPI +} + +// NewMockPostAPI creates a new mock instance. +func NewMockPostAPI(ctrl *gomock.Controller) *MockPostAPI { + mock := &MockPostAPI{ctrl: ctrl} + mock.recorder = &MockPostAPIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPostAPI) EXPECT() *MockPostAPIMockRecorder { + return m.recorder +} + +// DM mocks base method. +func (m *MockPostAPI) DM(arg0, arg1 string, arg2 *model.Post) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DM", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DM indicates an expected call of DM. +func (mr *MockPostAPIMockRecorder) DM(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DM", reflect.TypeOf((*MockPostAPI)(nil).DM), arg0, arg1, arg2) +} + +// DeletePost mocks base method. +func (m *MockPostAPI) DeletePost(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePost", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePost indicates an expected call of DeletePost. +func (mr *MockPostAPIMockRecorder) DeletePost(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePost", reflect.TypeOf((*MockPostAPI)(nil).DeletePost), arg0) +} + +// GetPost mocks base method. +func (m *MockPostAPI) GetPost(arg0 string) (*model.Post, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPost", arg0) + ret0, _ := ret[0].(*model.Post) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPost indicates an expected call of GetPost. +func (mr *MockPostAPIMockRecorder) GetPost(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPost", reflect.TypeOf((*MockPostAPI)(nil).GetPost), arg0) +} + +// SendEphemeralPost mocks base method. +func (m *MockPostAPI) SendEphemeralPost(arg0 string, arg1 *model.Post) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SendEphemeralPost", arg0, arg1) +} + +// SendEphemeralPost indicates an expected call of SendEphemeralPost. +func (mr *MockPostAPIMockRecorder) SendEphemeralPost(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendEphemeralPost", reflect.TypeOf((*MockPostAPI)(nil).SendEphemeralPost), arg0, arg1) +} + +// UpdatePost mocks base method. +func (m *MockPostAPI) UpdatePost(arg0 *model.Post) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePost", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdatePost indicates an expected call of UpdatePost. +func (mr *MockPostAPIMockRecorder) UpdatePost(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePost", reflect.TypeOf((*MockPostAPI)(nil).UpdatePost), arg0) +} diff --git a/server/public/pluginapi/experimental/bot/poster/poster.go b/server/public/pluginapi/experimental/bot/poster/poster.go new file mode 100644 index 0000000000..2fe032a7a3 --- /dev/null +++ b/server/public/pluginapi/experimental/bot/poster/poster.go @@ -0,0 +1,38 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package poster + +import ( + "github.com/mattermost/mattermost/server/public/model" +) + +// Poster defines an entity that can post DMs and Ephemerals and update and delete those posts +type Poster interface { + DMer + + // DMWithAttachments posts a Direct Message that contains Slack attachments. + // Often used to include post actions. + DMWithAttachments(mattermostUserID string, attachments ...*model.SlackAttachment) (string, error) + + // Ephemeral sends an ephemeral message to a user + Ephemeral(mattermostUserID, channelID, format string, args ...interface{}) + + // UpdatePostByID updates the post with postID with the formatted message + UpdatePostByID(postID, format string, args ...interface{}) error + + // DeletePost deletes a single post + DeletePost(postID string) error + + // DMUpdatePost substitute one post with another + UpdatePost(post *model.Post) error + + // UpdatePosterID updates the Mattermost User ID of the poster + UpdatePosterID(id string) +} + +// DMer defines an entity that can send Direct Messages +type DMer interface { + // DM posts a simple Direct Message to the specified user + DM(mattermostUserID, format string, args ...interface{}) (string, error) +} diff --git a/server/public/pluginapi/experimental/command/command.go b/server/public/pluginapi/experimental/command/command.go new file mode 100644 index 0000000000..da8b0ff0bd --- /dev/null +++ b/server/public/pluginapi/experimental/command/command.go @@ -0,0 +1,31 @@ +package command + +import ( + "encoding/base64" + "fmt" + "os" + "path/filepath" + + "github.com/pkg/errors" +) + +// PluginAPI is the plugin API interface required to manage slash commands. +type PluginAPI interface { + GetBundlePath() (string, error) +} + +// GetIconData returns the base64 encoding of a icon for a given path. +// The data returned may be used for slash command autocomplete. +func GetIconData(api PluginAPI, iconPath string) (string, error) { + bundlePath, err := api.GetBundlePath() + if err != nil { + return "", errors.Wrap(err, "couldn't get bundle path") + } + + icon, err := os.ReadFile(filepath.Join(bundlePath, iconPath)) + if err != nil { + return "", errors.Wrap(err, "failed to open icon") + } + + return fmt.Sprintf("data:image/svg+xml;base64,%s", base64.StdEncoding.EncodeToString(icon)), nil +} diff --git a/server/public/pluginapi/experimental/command/info.go b/server/public/pluginapi/experimental/command/info.go new file mode 100644 index 0000000000..807107f2d2 --- /dev/null +++ b/server/public/pluginapi/experimental/command/info.go @@ -0,0 +1,74 @@ +package command + +import ( + "fmt" + "regexp" + "runtime/debug" + "strings" + "time" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" +) + +var versionRegexp = regexp.MustCompile(`/v\d$`) + +func BuildInfoAutocomplete(cmd string) *model.AutocompleteData { + return model.NewAutocompleteData(cmd, "", "Display build info") +} + +func BuildInfo(manifest model.Manifest) (string, error) { + info, ok := debug.ReadBuildInfo() + if !ok { + return "", errors.New("failed to read build info") + } + + var ( + revision string + revisionShort string + buildTime time.Time + dirty bool + ) + for _, s := range info.Settings { + switch s.Key { + case "vcs.revision": + revision = s.Value + revisionShort = revision[0:7] + case "vcs.time": + var err error + buildTime, err = time.Parse(time.RFC3339, s.Value) + + if err != nil { + return "", err + } + case "vcs.modified": + if s.Value == "true" { + dirty = true + } + } + } + + path := info.Main.Path + + matches := versionRegexp.FindAllString(path, -1) + if len(matches) > 0 { + path = strings.TrimSuffix(path, matches[len(matches)-1]) + } + + dirtyText := "" + if dirty { + dirtyText = " (dirty)" + } + + commit := fmt.Sprintf("[%s](https://%s/commit/%s)", revisionShort, path, revision) + + return fmt.Sprintf("%s version: %s, %s%s, built %s with %s\n", + manifest.Name, + manifest.Version, + commit, + dirtyText, + buildTime.Format(time.RFC1123), + info.GoVersion), + nil +} diff --git a/server/public/pluginapi/experimental/common/kvstore.go b/server/public/pluginapi/experimental/common/kvstore.go new file mode 100644 index 0000000000..7e797cffdb --- /dev/null +++ b/server/public/pluginapi/experimental/common/kvstore.go @@ -0,0 +1,21 @@ +package common + +import ( + "errors" + "time" + + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +var ErrNotFound = errors.New("not found") + +type KVStore interface { + Set(key string, value interface{}, options ...pluginapi.KVSetOption) (bool, error) + SetWithExpiry(key string, value interface{}, ttl time.Duration) error + CompareAndSet(key string, oldValue, value interface{}) (bool, error) + CompareAndDelete(key string, oldValue interface{}) (bool, error) + Get(key string, o interface{}) error + Delete(key string) error + DeleteAll() error + ListKeys(page, count int, options ...pluginapi.ListKeysOption) ([]string, error) +} diff --git a/server/public/pluginapi/experimental/common/logapi.go b/server/public/pluginapi/experimental/common/logapi.go new file mode 100644 index 0000000000..8613ef59a0 --- /dev/null +++ b/server/public/pluginapi/experimental/common/logapi.go @@ -0,0 +1,8 @@ +package common + +type LogAPI interface { + LogError(message string, keyValuePairs ...interface{}) + LogWarn(message string, keyValuePairs ...interface{}) + LogInfo(message string, keyValuePairs ...interface{}) + LogDebug(message string, keyValuePairs ...interface{}) +} diff --git a/server/public/pluginapi/experimental/common/markdown.go b/server/public/pluginapi/experimental/common/markdown.go new file mode 100644 index 0000000000..34d418dd97 --- /dev/null +++ b/server/public/pluginapi/experimental/common/markdown.go @@ -0,0 +1,22 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package common + +import ( + "encoding/json" + "fmt" +) + +func JSON(ref interface{}) string { + bb, _ := json.MarshalIndent(ref, "", " ") + return string(bb) +} + +func CodeBlock(in string) string { + return fmt.Sprintf("\n```\n%s\n```\n", in) +} + +func JSONBlock(ref interface{}) string { + return fmt.Sprintf("\n```json\n%s\n```\n", JSON(ref)) +} diff --git a/server/public/pluginapi/experimental/common/slack_attachments.go b/server/public/pluginapi/experimental/common/slack_attachments.go new file mode 100644 index 0000000000..52853c6216 --- /dev/null +++ b/server/public/pluginapi/experimental/common/slack_attachments.go @@ -0,0 +1,24 @@ +package common + +import ( + "encoding/json" + "net/http" + + "github.com/mattermost/mattermost/server/public/model" +) + +func SlackAttachmentError(w http.ResponseWriter, err error) { + response := model.PostActionIntegrationResponse{ + EphemeralText: "Error:" + err.Error(), + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) +} + +func DialogError(w http.ResponseWriter, err error) { + response := model.SubmitDialogResponse{ + Error: err.Error(), + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) +} diff --git a/server/public/pluginapi/experimental/common/url.go b/server/public/pluginapi/experimental/common/url.go new file mode 100644 index 0000000000..2faa978a83 --- /dev/null +++ b/server/public/pluginapi/experimental/common/url.go @@ -0,0 +1,28 @@ +package common + +import ( + "net/url" + "strings" + + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +// GetPluginURL returns a url like siteURL/plugins/pluginID based on the information from the client. +// If any error happens in the process, a empty string is returned. +func GetPluginURL(client *pluginapi.Client) string { + mattermostSiteURL := client.Configuration.GetConfig().ServiceSettings.SiteURL + if mattermostSiteURL == nil { + return "" + } + _, err := url.Parse(*mattermostSiteURL) + if err != nil { + return "" + } + manifest, err := client.System.GetManifest() + if err != nil { + return "" + } + + pluginURLPath := "/plugins/" + manifest.Id + return strings.TrimRight(*mattermostSiteURL, "/") + pluginURLPath +} diff --git a/server/public/pluginapi/experimental/flow/flow.go b/server/public/pluginapi/experimental/flow/flow.go new file mode 100644 index 0000000000..036b1e6099 --- /dev/null +++ b/server/public/pluginapi/experimental/flow/flow.go @@ -0,0 +1,254 @@ +package flow + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/gorilla/mux" + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +type Name string + +const ( + contextStepKey = "step" + contextButtonKey = "button" +) + +type Flow struct { + UserID string + state *flowState + + name Name + api *pluginapi.Client + pluginURL string + botUserID string + + steps map[Name]Step + index []Name + done func(userID string, state State) error + + debugLogState bool +} + +// NewFlow creates a new flow using direct messages with the user. +// +// name must be a unique identifier for the flow within the plugin. +func NewFlow(name Name, api *pluginapi.Client, pluginURL, botUserID string) *Flow { + return &Flow{ + name: name, + api: api, + pluginURL: pluginURL, + botUserID: botUserID, + steps: map[Name]Step{}, + } +} + +func (f *Flow) WithSteps(orderedSteps ...Step) *Flow { + if f.steps == nil { + f.steps = map[Name]Step{} + } + for _, step := range orderedSteps { + stepName := step.name + if _, ok := f.steps[stepName]; ok { + f.api.Log.Warn("ignored duplicate step name", "name", stepName, "flow", f.name) + continue + } + f.steps[stepName] = step + f.index = append(f.index, stepName) + } + return f +} + +func (f *Flow) OnDone(done func(string, State) error) *Flow { + f.done = done + return f +} + +func (f *Flow) InitHTTP(r *mux.Router) *Flow { + flowRouter := r.PathPrefix("/").Subrouter() + flowRouter.HandleFunc(namePath(f.name)+"/button", f.handleButtonHTTP).Methods(http.MethodPost) + flowRouter.HandleFunc(namePath(f.name)+"/dialog", f.handleDialogHTTP).Methods(http.MethodPost) + return f +} + +func (f *Flow) WithDebugLog() *Flow { + f.debugLogState = true + return f +} + +// ForUser creates a new flow using direct messages with the user. +func (f *Flow) ForUser(userID string) *Flow { + clone := *f + clone.UserID = userID + clone.state = nil + return &clone +} + +func (f *Flow) GetCurrentStep() (Name, error) { + state, err := f.getState() + if err != nil { + // Don't return an error if no flow is running + if errors.Is(err, errStateNotFound) { + return "", nil + } + + return "", err + } + + return state.StepName, err +} + +func (f *Flow) GetState() State { + state, _ := f.getState() + return state.AppState +} + +func (f *Flow) Start(appState State) error { + if len(f.index) == 0 { + return errors.New("no steps") + } + + err := f.storeState(flowState{ + AppState: appState, + }) + if err != nil { + return err + } + + return f.Go(f.index[0]) +} + +func (f *Flow) Finish() error { + state, err := f.getState() + if err != nil { + return err + } + + _ = f.removeState() + + if f.done != nil { + err = f.done(f.UserID, state.AppState) + } + return err +} + +func (f *Flow) Go(toName Name) error { + state, err := f.getState() + if err != nil { + return err + } + if toName == state.StepName { + // Stay at the current step, nothing to do + return nil + } + // Moving onto a different step, mark the current step as "Done" + if state.StepName != "" && !state.Done { + from, ok := f.steps[state.StepName] + if !ok { + return errors.Errorf("%s: step not found", toName) + } + + var donePost *model.Post + donePost, err = from.done(f, 0) + if err != nil { + return err + } + if donePost != nil { + donePost.Id = state.PostID + err = f.api.Post.UpdatePost(donePost) + if err != nil { + return err + } + } + } + + if toName == "" { + return f.Finish() + } + to, ok := f.steps[toName] + if !ok { + return errors.Errorf("%s: step not found", toName) + } + + post, terminal, err := to.do(f) + if err != nil { + return err + } + f.processButtonPostActions(post) + + if f.debugLogState { + data, _ := json.MarshalIndent(state, "", " ") + post.Message = fmt.Sprintf("State:\n```\n%s\n```\n", string(data)) + } + + err = f.api.Post.DM(f.botUserID, f.UserID, post) + if err != nil { + return err + } + if terminal { + return f.Finish() + } + + state.StepName = toName + state.Done = false + state.PostID = post.Id + err = f.storeState(state) + if err != nil { + return err + } + + if to.autoForward { + var nextName Name + + if to.forwardTo != "" { + nextName = to.forwardTo + } else { + nextName = f.next(toName) + } + + if nextName != "" { + return f.Go(nextName) + } + } + + return nil +} + +func (f Flow) next(fromName Name) Name { + for i, n := range f.index { + if fromName == n { + if i+1 < len(f.index) { + return f.index[i+1] + } + return "" + } + } + return "" +} + +func namePath(name Name) string { + return "/" + url.PathEscape(strings.Trim(string(name), "/")) +} + +func Goto(toName Name) func(*Flow) (Name, State, error) { + return func(_ *Flow) (Name, State, error) { + return toName, nil, nil + } +} + +func DialogGoto(toName Name) func(*Flow, map[string]interface{}) (Name, State, map[string]string, error) { + return func(_ *Flow, submitted map[string]interface{}) (Name, State, map[string]string, error) { + stateUpdate := State{} + for k, v := range submitted { + stateUpdate[k] = fmt.Sprintf("%v", v) + } + return toName, stateUpdate, nil, nil + } +} diff --git a/server/public/pluginapi/experimental/flow/handler.go b/server/public/pluginapi/experimental/flow/handler.go new file mode 100644 index 0000000000..5d5e4dac01 --- /dev/null +++ b/server/public/pluginapi/experimental/flow/handler.go @@ -0,0 +1,210 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package flow + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" +) + +func (f *Flow) handleButtonHTTP(w http.ResponseWriter, r *http.Request) { + userID := r.Header.Get("Mattermost-User-ID") + if userID == "" { + common.SlackAttachmentError(w, errors.New("Not authorized")) + return + } + f = f.ForUser(userID) + + var request model.PostActionIntegrationRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + common.SlackAttachmentError(w, errors.New("invalid request")) + return + } + + // selectedButton is 1-based + fromName, selectedButton, err := buttonContext(&request) + if err != nil { + common.SlackAttachmentError(w, err) + return + } + + donePost, err := f.handleButton(fromName, selectedButton, request.TriggerId) + if err != nil { + common.SlackAttachmentError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(model.PostActionIntegrationResponse{ + Update: donePost, + }) +} + +func (f *Flow) handleDialogHTTP(w http.ResponseWriter, r *http.Request) { + userID := r.Header.Get("Mattermost-User-ID") + if userID == "" { + common.DialogError(w, errors.New("not authorized")) + return + } + f = f.ForUser(userID) + + var request model.SubmitDialogRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + common.DialogError(w, errors.New("invalid request")) + return + } + fromName, selectedButton, err := dialogContext(&request) + if err != nil { + common.DialogError(w, errors.Wrap(err, "invalid request")) + return + } + + // handleDialog updates the post + donePost, fieldErrors, err := f.handleDialog(fromName, selectedButton, request.Submission) + if err != nil || len(fieldErrors) != 0 { + w.Header().Set("Content-Type", "application/json") + + resp := model.SubmitDialogResponse{ + Errors: fieldErrors, + } + + if err != nil { + resp.Error = err.Error() + } + + _ = json.NewEncoder(w).Encode(resp) + return + } + err = f.api.Post.UpdatePost(donePost) + if err != nil { + common.DialogError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(model.SubmitDialogResponse{}) +} + +func (f *Flow) handleButton(fromName Name, selectedButton int, triggerID string) (*model.Post, error) { + post, _, err := f.handle(fromName, selectedButton, nil, triggerID, true) + return post, err +} + +func (f *Flow) handleDialog( + fromName Name, selectedButton int, submission map[string]interface{}, +) ( + *model.Post, map[string]string, error, +) { + return f.handle(fromName, selectedButton, submission, "", false) +} + +func (f *Flow) handle( + fromName Name, selectedButton int, submission map[string]interface{}, triggerID string, asButton bool, +) ( + *model.Post, map[string]string, error, +) { + state, err := f.getState() + if err != nil { + return nil, nil, err + } + if state.StepName != fromName { + return nil, nil, errors.Errorf("click from an inactive step: %v", fromName) + } + from, ok := f.steps[fromName] + if !ok { + return nil, nil, errors.Errorf("step %q not found", fromName) + } + + if selectedButton == 0 || selectedButton > len(from.buttons) { + return nil, nil, errors.Errorf("button number %v to high or too low, only %v buttons", selectedButton, len(from.buttons)) + } + b := from.buttons[selectedButton-1] + + var updated State + toName := fromName + var fieldErrors map[string]string + if asButton { + if b.OnClick != nil { + toName, updated, err = b.OnClick(f) + } + } else { + if b.OnDialogSubmit != nil { + toName, updated, fieldErrors, err = b.OnDialogSubmit(f, submission) + } + } + if err != nil || len(fieldErrors) > 0 { + return nil, fieldErrors, err + } + state.AppState = state.AppState.MergeWith(updated) + state.Done = true + err = f.storeState(state) + if err != nil { + return nil, nil, err + } + + // Empty next step name in the response indicates advancing to the next step + // in the flow. To stay on the same step the handlers should return the step + // name. + if toName == "" { + toName = f.next(fromName) + } + + if asButton && b.Dialog != nil { + if b.OnDialogSubmit == nil { + return nil, nil, errors.Errorf("no submit function for dialog, step: %s", fromName) + } + + dialogRequest := model.OpenDialogRequest{ + TriggerId: triggerID, + URL: f.pluginURL + namePath(f.name) + "/dialog", + Dialog: processDialog(b.Dialog, state.AppState), + } + dialogRequest.Dialog.State = fmt.Sprintf("%v,%v", fromName, selectedButton) + + err = f.api.Frontend.OpenInteractiveDialog(dialogRequest) + if err != nil { + return nil, nil, err + } + } + + if toName == fromName { + // Nothing else to do + return nil, nil, nil + } + + donePost, err := from.done(f, selectedButton) + if err != nil { + return nil, nil, err + } + donePost.Id = state.PostID + f.processButtonPostActions(donePost) + + err = f.Go(toName) + if err != nil { + f.api.Log.Warn("failed to advance flow to next step", "flow_name", f.name, "from", fromName, "to", toName, "error", err.Error()) + } + + // return the "done" post for the from step - leave updating up to the + // API-specific caller. + return donePost, nil, nil +} + +func (f *Flow) processButtonPostActions(post *model.Post) { + attachments, ok := post.GetProp("attachments").([]*model.SlackAttachment) + if !ok || len(attachments) == 0 { + return + } + sa := attachments[0] + for _, a := range sa.Actions { + if a.Integration == nil { + a.Integration = &model.PostActionIntegration{} + } + a.Integration.URL = f.pluginURL + namePath(f.name) + "/button" + } +} diff --git a/server/public/pluginapi/experimental/flow/state.go b/server/public/pluginapi/experimental/flow/state.go new file mode 100644 index 0000000000..e9393dc3c8 --- /dev/null +++ b/server/public/pluginapi/experimental/flow/state.go @@ -0,0 +1,146 @@ +package flow + +import ( + "bytes" + "errors" + "text/template" +) + +var errStateNotFound = errors.New("flow state not found") + +// State is the "app"'s state +type State map[string]interface{} + +func (s State) MergeWith(update State) State { + n := State{} + for k, v := range s { + n[k] = v + } + for k, v := range update { + n[k] = v + } + return n +} + +// GetString return the value to a given key as a string. +// If the key is not found or isn't a string, an empty string is returned. +func (s State) GetString(key string) string { + vRaw, ok := s[key] + if ok { + v, ok := vRaw.(string) + if ok { + return v + } + } + + return "" +} + +// GetInt return the value to a given key as a int. +// If the key is not found or isn't an int, zero is returned. +func (s State) GetInt(key string) int { + vRaw, ok := s[key] + if ok { + v, ok := vRaw.(int) + if ok { + return v + } + } + + return 0 +} + +// GetBool return the value to a given key as a bool. +// If the key is not found or isn't a bool, false is returned. +func (s State) GetBool(key string) bool { + vRaw, ok := s[key] + if ok { + v, ok := vRaw.(bool) + if ok { + return v + } + } + + return false +} + +// JSON-serializable flow state. +type flowState struct { + // The name of the step. + StepName Name + + Done bool + + // ID of the post produced by the step. + PostID string + + // Application-level state. + AppState State +} + +func (f *Flow) storeState(state flowState) error { + if f.UserID == "" { + return errors.New("no user specified") + } + + // Set AppState to differentiate an existing flow + if state.AppState == nil { + state.AppState = State{} + } + + ok, err := f.api.KV.Set(kvKey(f.UserID, f.name), state) + if err != nil { + return err + } + if !ok { + return errors.New("value not set without errors") + } + + f.state = &state + return nil +} + +func (f *Flow) getState() (flowState, error) { + if f.UserID == "" { + return flowState{}, errors.New("no user specified") + } + if f.state != nil { + return *f.state, nil + } + state := flowState{} + err := f.api.KV.Get(kvKey(f.UserID, f.name), &state) + if err != nil { + return flowState{}, err + } + if state.AppState == nil { + return flowState{}, errStateNotFound + } + + f.state = &state + return state, err +} + +func (f *Flow) removeState() error { + if f.UserID == "" { + return errors.New("no user specified") + } + f.state = nil + return f.api.KV.Delete(kvKey(f.UserID, f.name)) +} + +func kvKey(userID string, flowName Name) string { + return "_flow-" + userID + "-" + string(flowName) +} + +func formatState(source string, state State) string { + t, err := template.New("message").Parse(source) + if err != nil { + return source + " ###ERROR: " + err.Error() + } + buf := bytes.NewBuffer(nil) + err = t.Execute(buf, state) + if err != nil { + return source + " ###ERROR: " + err.Error() + } + return buf.String() +} diff --git a/server/public/pluginapi/experimental/flow/step.go b/server/public/pluginapi/experimental/flow/step.go new file mode 100644 index 0000000000..c4b5e829c8 --- /dev/null +++ b/server/public/pluginapi/experimental/flow/step.go @@ -0,0 +1,269 @@ +package flow + +import ( + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" +) + +type Color string + +const ( + ColorDefault Color = "default" + ColorPrimary Color = "primary" + ColorSuccess Color = "success" + ColorGood Color = "good" + ColorWarning Color = "warning" + ColorDanger Color = "danger" +) + +type Step struct { + name Name + template *model.SlackAttachment + forwardTo Name + autoForward bool + terminal bool + onRender func(f *Flow) + buttons []Button +} + +type Button struct { + Name string + Disabled bool + Color Color + + // OnClick is called when the button is clicked. It returns the next step's + // name and the state updates to apply. + // + // If Dialog is also specified, OnClick is executed first. + OnClick func(f *Flow) (Name, State, error) + + // Dialog is the interactive dialog to display if the button is clicked + // (OnClick is executed first). OnDialogSubmit must be provided. + Dialog *model.Dialog + + // Function that is called when the dialog box is submitted. It can return a + // general error, or field-specific errors. On success it returns the name + // of the next step, and the state updates to apply. + OnDialogSubmit func(f *Flow, submitted map[string]interface{}) (Name, State, map[string]string, error) +} + +func NewStep(name Name) Step { + return Step{ + name: name, + template: &model.SlackAttachment{}, + } +} + +func (s Step) WithButton(buttons ...Button) Step { + s.buttons = append(s.buttons, buttons...) + return s +} + +func (s Step) Terminal() Step { + s.terminal = true + return s +} + +func (s Step) OnRender(f func(*Flow)) Step { + s.onRender = f + return s +} + +func (s Step) Next(name Name) Step { + s.forwardTo = name + s.autoForward = true + return s +} + +func (s Step) WithImage(imageURL string) Step { + if u, err := url.Parse(imageURL); err == nil { + if u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") { + s.template.ImageURL = imageURL + } else { + s.template.ImageURL = u.Path + } + } + return s +} + +func (s Step) WithColor(color Color) Step { + s.template.Color = string(color) + return s +} + +func (s Step) WithPretext(text string) Step { + s.template.Pretext = text + return s +} + +func (s Step) WithField(title, value string) Step { + s.template.Fields = append(s.template.Fields, &model.SlackAttachmentField{ + Title: title, + Value: value, + }) + return s +} + +func (s Step) WithTitle(text string) Step { + s.template.Title = text + return s +} + +func (s Step) WithText(text string) Step { + s.template.Text = text + return s +} + +func (s Step) do(f *Flow) (*model.Post, bool, error) { + if s.onRender != nil { + s.onRender(f) + } + + return s.render(f, false, 0) +} + +func (s Step) done(f *Flow, selectedButton int) (*model.Post, error) { + post, _, err := s.render(f, true, selectedButton) + return post, err +} + +func (s Step) render(f *Flow, done bool, selectedButton int) (*model.Post, bool, error) { + sa := f.processAttachment(s.template) + post := model.Post{} + model.ParseSlackAttachment(&post, []*model.SlackAttachment{sa}) + + if s.terminal { + // Nothing else to do, do not display buttons on terminal posts. + return &post, true, nil + } + + buttons := processButtons(s.buttons, f.state.AppState) + + attachments, ok := post.GetProp("attachments").([]*model.SlackAttachment) + if !ok || len(attachments) != 1 { + return nil, false, errors.New("expected 1 slack attachment") + } + var actions []*model.PostAction + if done { + if selectedButton > 0 { + action := renderButton(buttons[selectedButton-1], s.name, selectedButton, f.state.AppState) + action.Disabled = true + actions = append(actions, action) + } + } else { + for i, b := range buttons { + actions = append(actions, renderButton(b, s.name, i+1, f.state.AppState)) + } + } + attachments[0].Actions = actions + return &post, false, nil +} + +func (f *Flow) processAttachment(attachment *model.SlackAttachment) *model.SlackAttachment { + if attachment == nil { + return &model.SlackAttachment{Text: "ERROR"} + } + a := *attachment + a.Pretext = formatState(attachment.Pretext, f.state.AppState) + a.Title = formatState(attachment.Title, f.state.AppState) + a.Text = formatState(attachment.Text, f.state.AppState) + + for _, field := range a.Fields { + field.Title = formatState(field.Title, f.state.AppState) + v := field.Value.(string) + if v != "" { + field.Value = formatState(v, f.state.AppState) + } + } + + a.Fallback = fmt.Sprintf("%s: %s", a.Title, a.Text) + + if attachment.ImageURL != "" { + if u, err := url.Parse(attachment.ImageURL); err == nil { + if u.Host != "" && (u.Scheme == "http" || u.Scheme == "https") { + a.ImageURL = attachment.ImageURL + } else { + a.ImageURL = f.pluginURL + "/" + strings.TrimPrefix(attachment.ImageURL, "/") + } + } + } + + return &a +} + +func processButtons(in []Button, state State) []Button { + var out []Button + for _, b := range in { + button := b + button.Name = formatState(b.Name, state) + out = append(out, button) + } + return out +} + +func processDialog(in *model.Dialog, state State) model.Dialog { + d := *in + d.Title = formatState(d.Title, state) + d.IntroductionText = formatState(d.IntroductionText, state) + d.SubmitLabel = formatState(d.SubmitLabel, state) + for i := range d.Elements { + d.Elements[i].DisplayName = formatState(d.Elements[i].DisplayName, state) + d.Elements[i].Name = formatState(d.Elements[i].Name, state) + d.Elements[i].Default = formatState(d.Elements[i].Default, state) + d.Elements[i].Placeholder = formatState(d.Elements[i].Placeholder, state) + d.Elements[i].HelpText = formatState(d.Elements[i].HelpText, state) + } + return d +} + +func renderButton(b Button, stepName Name, i int, state State) *model.PostAction { + return &model.PostAction{ + Name: formatState(b.Name, state), + Disabled: b.Disabled, + Style: string(b.Color), + Integration: &model.PostActionIntegration{ + Context: map[string]interface{}{ + contextStepKey: string(stepName), + contextButtonKey: strconv.Itoa(i), + }, + }, + } +} + +func buttonContext(request *model.PostActionIntegrationRequest) (Name, int, error) { + fromString, ok := request.Context[contextStepKey].(string) + if !ok { + return "", 0, errors.New("missing step name") + } + fromName := Name(fromString) + + buttonStr, ok := request.Context[contextButtonKey].(string) + if !ok { + return "", 0, errors.New("missing button id") + } + buttonIndex, err := strconv.Atoi(buttonStr) + if err != nil { + return "", 0, errors.Wrap(err, "invalid button number") + } + + return fromName, buttonIndex, nil +} + +func dialogContext(request *model.SubmitDialogRequest) (Name, int, error) { + data := strings.Split(request.State, ",") + if len(data) != 2 { + return "", 0, errors.New("invalid request") + } + fromName := Name(data[0]) + buttonIndex, err := strconv.Atoi(data[1]) + if err != nil { + return "", 0, errors.Wrap(err, "malformed button number") + } + return fromName, buttonIndex, nil +} diff --git a/server/public/pluginapi/experimental/oauther/mock_oauther/mock_oauther.go b/server/public/pluginapi/experimental/oauther/mock_oauther/mock_oauther.go new file mode 100644 index 0000000000..22c1f8d594 --- /dev/null +++ b/server/public/pluginapi/experimental/oauther/mock_oauther/mock_oauther.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost-plugin-mscalendar/server/utils/oauther (interfaces: OAuther) + +// Package mock_oauther is a generated GoMock package. +package mock_oauther + +import ( + gomock "github.com/golang/mock/gomock" + oauth2 "golang.org/x/oauth2" + reflect "reflect" +) + +// MockOAuther is a mock of OAuther interface +type MockOAuther struct { + ctrl *gomock.Controller + recorder *MockOAutherMockRecorder +} + +// MockOAutherMockRecorder is the mock recorder for MockOAuther +type MockOAutherMockRecorder struct { + mock *MockOAuther +} + +// NewMockOAuther creates a new mock instance +func NewMockOAuther(ctrl *gomock.Controller) *MockOAuther { + mock := &MockOAuther{ctrl: ctrl} + mock.recorder = &MockOAutherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockOAuther) EXPECT() *MockOAutherMockRecorder { + return m.recorder +} + +// Deauth mocks base method +func (m *MockOAuther) Deauth(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Deauth", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Deauth indicates an expected call of Deauth +func (mr *MockOAutherMockRecorder) Deauth(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Deauth", reflect.TypeOf((*MockOAuther)(nil).Deauth), arg0) +} + +// GetToken mocks base method +func (m *MockOAuther) GetToken(arg0 string) (*oauth2.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetToken", arg0) + ret0, _ := ret[0].(*oauth2.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetToken indicates an expected call of GetToken +func (mr *MockOAutherMockRecorder) GetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToken", reflect.TypeOf((*MockOAuther)(nil).GetToken), arg0) +} + +// GetURL mocks base method +func (m *MockOAuther) GetURL() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetURL") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetURL indicates an expected call of GetURL +func (mr *MockOAutherMockRecorder) GetURL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetURL", reflect.TypeOf((*MockOAuther)(nil).GetURL)) +} diff --git a/server/public/pluginapi/experimental/oauther/mocks/mock_oauther.go b/server/public/pluginapi/experimental/oauther/mocks/mock_oauther.go new file mode 100644 index 0000000000..83e6cc7256 --- /dev/null +++ b/server/public/pluginapi/experimental/oauther/mocks/mock_oauther.go @@ -0,0 +1,105 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/oauther (interfaces: OAuther) + +// Package mock_oauther is a generated GoMock package. +package mock_oauther + +import ( + http "net/http" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + oauth2 "golang.org/x/oauth2" +) + +// MockOAuther is a mock of OAuther interface. +type MockOAuther struct { + ctrl *gomock.Controller + recorder *MockOAutherMockRecorder +} + +// MockOAutherMockRecorder is the mock recorder for MockOAuther. +type MockOAutherMockRecorder struct { + mock *MockOAuther +} + +// NewMockOAuther creates a new mock instance. +func NewMockOAuther(ctrl *gomock.Controller) *MockOAuther { + mock := &MockOAuther{ctrl: ctrl} + mock.recorder = &MockOAutherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOAuther) EXPECT() *MockOAutherMockRecorder { + return m.recorder +} + +// AddPayload mocks base method. +func (m *MockOAuther) AddPayload(arg0 string, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddPayload", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddPayload indicates an expected call of AddPayload. +func (mr *MockOAutherMockRecorder) AddPayload(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPayload", reflect.TypeOf((*MockOAuther)(nil).AddPayload), arg0, arg1) +} + +// Deauthorize mocks base method. +func (m *MockOAuther) Deauthorize(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Deauthorize", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Deauthorize indicates an expected call of Deauthorize. +func (mr *MockOAutherMockRecorder) Deauthorize(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Deauthorize", reflect.TypeOf((*MockOAuther)(nil).Deauthorize), arg0) +} + +// GetConnectURL mocks base method. +func (m *MockOAuther) GetConnectURL() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConnectURL") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetConnectURL indicates an expected call of GetConnectURL. +func (mr *MockOAutherMockRecorder) GetConnectURL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnectURL", reflect.TypeOf((*MockOAuther)(nil).GetConnectURL)) +} + +// GetToken mocks base method. +func (m *MockOAuther) GetToken(arg0 string) (*oauth2.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetToken", arg0) + ret0, _ := ret[0].(*oauth2.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetToken indicates an expected call of GetToken. +func (mr *MockOAutherMockRecorder) GetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToken", reflect.TypeOf((*MockOAuther)(nil).GetToken), arg0) +} + +// ServeHTTP mocks base method. +func (m *MockOAuther) ServeHTTP(arg0 http.ResponseWriter, arg1 *http.Request) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ServeHTTP", arg0, arg1) +} + +// ServeHTTP indicates an expected call of ServeHTTP. +func (mr *MockOAutherMockRecorder) ServeHTTP(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeHTTP", reflect.TypeOf((*MockOAuther)(nil).ServeHTTP), arg0, arg1) +} diff --git a/server/public/pluginapi/experimental/oauther/oauth2.go b/server/public/pluginapi/experimental/oauther/oauth2.go new file mode 100644 index 0000000000..ccbaea2c13 --- /dev/null +++ b/server/public/pluginapi/experimental/oauther/oauth2.go @@ -0,0 +1,191 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package oauther + +import ( + "net/http" + "time" + + "golang.org/x/oauth2" + + "github.com/mattermost/mattermost/server/public/pluginapi" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" +) + +const ( + // DefaultStorePrefix is the prefix used when storing information in the KVStore by default. + DefaultStorePrefix = "oauth_" + // DefaultOAuthURL is the URL the OAuther will use to register its endpoints by default. + DefaultOAuthURL = "/oauth2" + // DefaultConnectedString is the string shown to the user when the oauth flow is completed by default. + DefaultConnectedString = "Successfully connected. Please close this window." + // DefaultOAuth2StateTimeToLive is the duration the states from the OAuth flow will live in the KVStore by default. + DefaultOAuth2StateTimeToLive = 5 * time.Minute + // DefaultPayloadTimeToLive is the duration the user payload will live in the KVStore by default. + DefaultPayloadTimeToLive = 10 * time.Minute +) + +const ( + connectURL = "/connect" + completeURL = "/complete" +) + +// OAuther defines an object able to perform the OAuth flow. +type OAuther interface { + // GetToken returns the oauth token for userID, or error if it does not exist or there is any store error. + GetToken(userID string) (*oauth2.Token, error) + // GetConnectURL returns the URL to reach in order to start the OAuth flow. + GetConnectURL() string + // Deauthorize removes the token for userID. Return error if there is any store error. + Deauthorize(userID string) error + // ServeHTTP implements http.Handler + ServeHTTP(w http.ResponseWriter, r *http.Request) + // AddPayload stores some information to be returned after the flow is over + AddPayload(userID string, payload []byte) error +} + +type oAuther struct { + pluginURL string + config oauth2.Config + onConnect func(userID string, token oauth2.Token, payload []byte) + store common.KVStore + logger logger.Logger + storePrefix string + oAuthURL string + connectedString string + oAuth2StateTimeToLive time.Duration + payloadTimeToLive time.Duration +} + +/* +New creates a new OAuther. + +- pluginURL: The base URL for the plugin (e.g. https://www.instance.com/plugins/pluginid). + +- oAuthConfig: The configuration of the Authorization flow to perform. + +- onConnect: What to do when the Authorization process is complete. + +- store: A KVStore to store the data of the OAuther. + +- l Logger: A logger to log errors during authorization. + +- options: Optional options for the OAuther. Available options are StorePrefix, OAuthURL, ConnectedString and OAuth2StateTimeToLive. +*/ +func New( + pluginURL string, + oAuthConfig oauth2.Config, + onConnect func(userID string, token oauth2.Token, payload []byte), + store common.KVStore, + l logger.Logger, + options ...Option, +) OAuther { + o := &oAuther{ + pluginURL: pluginURL, + config: oAuthConfig, + onConnect: onConnect, + store: store, + logger: l, + storePrefix: DefaultStorePrefix, + oAuthURL: DefaultOAuthURL, + connectedString: DefaultConnectedString, + oAuth2StateTimeToLive: DefaultOAuth2StateTimeToLive, + payloadTimeToLive: DefaultPayloadTimeToLive, + } + + for _, option := range options { + option(o) + } + + o.config.RedirectURL = o.pluginURL + o.oAuthURL + "/complete" + + return o +} + +/* +NewFromClient creates a new OAuther from the plugin api client. + +- pluginapi: A plugin api client. + +- pluginID: The plugin ID. + +- oAuthConfig: The configuration of the Authorization flow to perform. + +- onConnect: What to do when the Authorization process is complete. + +- l Logger: A logger to log errors during authorization. + +- options: Optional options for the OAuther. Available options are StorePrefix, OAuthURL, ConnectedString and OAuth2StateTimeToLive. +*/ +func NewFromClient( + client *pluginapi.Client, + oAuthConfig oauth2.Config, + onConnect func(userID string, token oauth2.Token, payload []byte), + l logger.Logger, + options ...Option, +) OAuther { + return New( + common.GetPluginURL(client), + oAuthConfig, + onConnect, + &client.KV, + l, + options..., + ) +} + +func (o *oAuther) GetConnectURL() string { + return o.pluginURL + o.oAuthURL + "/connect" +} + +func (o *oAuther) GetToken(userID string) (*oauth2.Token, error) { + var token *oauth2.Token + err := o.store.Get(o.getTokenKey(userID), &token) + if err != nil { + return nil, err + } + return token, nil +} + +func (o *oAuther) getTokenKey(userID string) string { + return o.storePrefix + "token_" + userID +} + +func (o *oAuther) getStateKey(userID string) string { + return o.storePrefix + "state_" + userID +} + +func (o *oAuther) getPayloadKey(userID string) string { + return o.storePrefix + "payload_" + userID +} + +func (o *oAuther) Deauthorize(userID string) error { + err := o.store.Delete(o.getTokenKey(userID)) + if err != nil { + return err + } + + return nil +} + +func (o *oAuther) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case o.oAuthURL + connectURL: + o.oauth2Connect(w, r) + case o.oAuthURL + completeURL: + o.oauth2Complete(w, r) + default: + http.NotFound(w, r) + } +} + +func (o *oAuther) AddPayload(userID string, payload []byte) error { + _, err := o.store.Set(o.getPayloadKey(userID), payload, pluginapi.SetExpiry(o.payloadTimeToLive)) + if err != nil { + return err + } + + return nil +} diff --git a/server/public/pluginapi/experimental/oauther/oauth2_complete.go b/server/public/pluginapi/experimental/oauther/oauth2_complete.go new file mode 100644 index 0000000000..486355c9aa --- /dev/null +++ b/server/public/pluginapi/experimental/oauther/oauth2_complete.go @@ -0,0 +1,105 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package oauther + +import ( + "context" + "fmt" + "net/http" + "strings" +) + +func (o *oAuther) oauth2Complete(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + authedUserID := r.Header.Get("Mattermost-User-ID") + if authedUserID == "" { + o.logger.Debugf("oauth2Complete: reached by non authed user") + http.Error(w, "Not authorized", http.StatusUnauthorized) + return + } + code := r.URL.Query().Get("code") + if code == "" { + o.logger.Debugf("oauth2Complete: reached with no code") + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + state := r.URL.Query().Get("state") + + var storedState string + err := o.store.Get(o.getStateKey(authedUserID), &storedState) + if err != nil { + o.logger.Warnf("oauth2Complete: cannot get state, err=%s", err.Error()) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + if storedState != state { + o.logger.Debugf("oauth2Complete: state mismatch") + o.logger.Debugf("received state '%s'; expected state '%s%", state, storedState) + http.Error(w, "Not authorized", http.StatusUnauthorized) + return + } + + userID := strings.Split(state, "_")[1] + if userID != authedUserID { + o.logger.Debugf("oauth2Complete: authed user mismatch") + http.Error(w, "Not authorized", http.StatusUnauthorized) + return + } + + ctx := context.Background() + token, err := o.config.Exchange(ctx, code) + if err != nil { + o.logger.Warnf("oauth2Complete: could not generate token, err=%s", err.Error()) + http.Error(w, "Not authorized", http.StatusUnauthorized) + return + } + + var payload []byte + err = o.store.Get(o.getPayloadKey(userID), &payload) + if err != nil { + o.logger.Errorf("oauth2Complete: could not fetch payload, err=&s", err.Error()) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + ok, err := o.store.Set(o.getTokenKey(userID), token) + if err != nil { + o.logger.Errorf("oauth2Complete: cannot store the token, err=%s", err.Error()) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + if !ok { + o.logger.Errorf("oauth2Complete: cannot store token without error") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + html := fmt.Sprintf(` + + + + + + +

%s

+ + + `, o.connectedString) + + w.Header().Set("Content-Type", "text/html") + _, err = w.Write([]byte(html)) + if err != nil { + o.logger.Errorf("oauth2Complete: error writing response, err=%s", err.Error()) + } + + if o.onConnect != nil { + o.onConnect(userID, *token, payload) + } +} diff --git a/server/public/pluginapi/experimental/oauther/oauth2_connect.go b/server/public/pluginapi/experimental/oauther/oauth2_connect.go new file mode 100644 index 0000000000..bc4822942c --- /dev/null +++ b/server/public/pluginapi/experimental/oauther/oauth2_connect.go @@ -0,0 +1,38 @@ +// Copyright (c) 2019-present Mattermost, Inc. All Rights Reserved. +// See License for license information. + +package oauther + +import ( + "fmt" + "net/http" + + "golang.org/x/oauth2" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func (o *oAuther) oauth2Connect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + userID := r.Header.Get("Mattermost-User-ID") + if userID == "" { + o.logger.Debugf("oauth2Connect: reached by non authed user") + http.Error(w, "Not authorized", http.StatusUnauthorized) + return + } + + state := fmt.Sprintf("%v_%v", model.NewId()[0:15], userID) + _, err := o.store.Set(o.getStateKey(userID), state, pluginapi.SetExpiry(o.oAuth2StateTimeToLive)) + if err != nil { + o.logger.Errorf("oauth2Connect: failed to store state, err=%s", err.Error()) + http.Error(w, "failed to store token state", http.StatusInternalServerError) + return + } + + redirectURL := o.config.AuthCodeURL(state, oauth2.AccessTypeOffline) + http.Redirect(w, r, redirectURL, http.StatusFound) +} diff --git a/server/public/pluginapi/experimental/oauther/options.go b/server/public/pluginapi/experimental/oauther/options.go new file mode 100644 index 0000000000..c371218265 --- /dev/null +++ b/server/public/pluginapi/experimental/oauther/options.go @@ -0,0 +1,47 @@ +package oauther + +import "time" + +// Option defines each option that can be passed in the creation of the OAuther. +// Options functions available are OAuthURL, StorePrefix, ConnectedString and OAuth2StateTimeToLive and PayloadTimeToLive. +type Option func(*oAuther) + +// OAuthURL defines the URL the OAuther will use to register its endpoints. +// Defaults to "/oauth2". +func OAuthURL(url string) Option { + return func(o *oAuther) { + o.oAuthURL = url + } +} + +// StorePrefix defines the prefix the OAuther will use to store information in the KVStore. +// Defaults to "oauth_". +func StorePrefix(prefix string) Option { + return func(o *oAuther) { + o.storePrefix = prefix + } +} + +// ConnectedString defines the string shown to the user when the oauth flow is completed. +// Defaults to "Successfully connected. Please close this window.". +func ConnectedString(text string) Option { + return func(o *oAuther) { + o.connectedString = text + } +} + +// OAuth2StateTimeToLive is the duration the states from the OAuth flow will live in the KVStore. +// Defaults to 5 minutes. +func OAuth2StateTimeToLive(ttl time.Duration) Option { + return func(o *oAuther) { + o.oAuth2StateTimeToLive = ttl + } +} + +// PayloadTimeToLive is the duration the payload from the OAuth flow will live in the KVStore. +// Defaults to 10 minutes. +func PayloadTimeToLive(ttl time.Duration) Option { + return func(o *oAuther) { + o.payloadTimeToLive = ttl + } +} diff --git a/server/public/pluginapi/experimental/panel/handler.go b/server/public/pluginapi/experimental/panel/handler.go new file mode 100644 index 0000000000..72bbeb735b --- /dev/null +++ b/server/public/pluginapi/experimental/panel/handler.go @@ -0,0 +1,70 @@ +package panel + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel/settings" +) + +type handler struct { + panel Panel +} + +func Init(r *mux.Router, panel Panel) { + sh := &handler{ + panel: panel, + } + + panelRouter := r.PathPrefix("/").Subrouter() + panelRouter.HandleFunc(panel.URL(), sh.handleAction).Methods(http.MethodPost) +} + +func (sh *handler) handleAction(w http.ResponseWriter, r *http.Request) { + mattermostUserID := r.Header.Get("Mattermost-User-ID") + if mattermostUserID == "" { + common.SlackAttachmentError(w, errors.New("Not authorized")) + return + } + + var request model.PostActionIntegrationRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + common.SlackAttachmentError(w, errors.New("invalid request")) + return + } + + id, ok := request.Context[settings.ContextIDKey] + if !ok { + common.SlackAttachmentError(w, errors.New("missing setting id")) + return + } + + value, ok := request.Context[settings.ContextButtonValueKey] + if !ok { + value, ok = request.Context[settings.ContextOptionValueKey] + if !ok { + common.SlackAttachmentError(w, errors.New("valid key not found")) + return + } + } + + idString := id.(string) + err := sh.panel.Set(mattermostUserID, idString, value) + if err != nil { + common.SlackAttachmentError(w, errors.Wrap(err, "cannot save setting")) + return + } + + response := model.PostActionIntegrationResponse{} + post, err := sh.panel.ToPost(mattermostUserID) + if err == nil { + response.Update = post + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) +} diff --git a/server/public/pluginapi/experimental/panel/mocks/mock_panel.go b/server/public/pluginapi/experimental/panel/mocks/mock_panel.go new file mode 100644 index 0000000000..c23eb558b3 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/mocks/mock_panel.go @@ -0,0 +1,118 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel (interfaces: Panel) + +// Package mock_panel is a generated GoMock package. +package mock_panel + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + model "github.com/mattermost/mattermost/server/public/model" +) + +// MockPanel is a mock of Panel interface. +type MockPanel struct { + ctrl *gomock.Controller + recorder *MockPanelMockRecorder +} + +// MockPanelMockRecorder is the mock recorder for MockPanel. +type MockPanelMockRecorder struct { + mock *MockPanel +} + +// NewMockPanel creates a new mock instance. +func NewMockPanel(ctrl *gomock.Controller) *MockPanel { + mock := &MockPanel{ctrl: ctrl} + mock.recorder = &MockPanelMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPanel) EXPECT() *MockPanelMockRecorder { + return m.recorder +} + +// Clear mocks base method. +func (m *MockPanel) Clear(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clear", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Clear indicates an expected call of Clear. +func (mr *MockPanelMockRecorder) Clear(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockPanel)(nil).Clear), arg0) +} + +// GetSettingIDs mocks base method. +func (m *MockPanel) GetSettingIDs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSettingIDs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetSettingIDs indicates an expected call of GetSettingIDs. +func (mr *MockPanelMockRecorder) GetSettingIDs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSettingIDs", reflect.TypeOf((*MockPanel)(nil).GetSettingIDs)) +} + +// Print mocks base method. +func (m *MockPanel) Print(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Print", arg0) +} + +// Print indicates an expected call of Print. +func (mr *MockPanelMockRecorder) Print(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Print", reflect.TypeOf((*MockPanel)(nil).Print), arg0) +} + +// Set mocks base method. +func (m *MockPanel) Set(arg0, arg1 string, arg2 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockPanelMockRecorder) Set(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockPanel)(nil).Set), arg0, arg1, arg2) +} + +// ToPost mocks base method. +func (m *MockPanel) ToPost(arg0 string) (*model.Post, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ToPost", arg0) + ret0, _ := ret[0].(*model.Post) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ToPost indicates an expected call of ToPost. +func (mr *MockPanelMockRecorder) ToPost(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ToPost", reflect.TypeOf((*MockPanel)(nil).ToPost), arg0) +} + +// URL mocks base method. +func (m *MockPanel) URL() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "URL") + ret0, _ := ret[0].(string) + return ret0 +} + +// URL indicates an expected call of URL. +func (mr *MockPanelMockRecorder) URL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "URL", reflect.TypeOf((*MockPanel)(nil).URL)) +} diff --git a/server/public/pluginapi/experimental/panel/mocks/mock_panelStore.go b/server/public/pluginapi/experimental/panel/mocks/mock_panelStore.go new file mode 100644 index 0000000000..90383b3314 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/mocks/mock_panelStore.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel (interfaces: Store) + +// Package mock_panel is a generated GoMock package. +package mock_panel + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// DeletePanelPostID mocks base method. +func (m *MockStore) DeletePanelPostID(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePanelPostID", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePanelPostID indicates an expected call of DeletePanelPostID. +func (mr *MockStoreMockRecorder) DeletePanelPostID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePanelPostID", reflect.TypeOf((*MockStore)(nil).DeletePanelPostID), arg0) +} + +// GetPanelPostID mocks base method. +func (m *MockStore) GetPanelPostID(arg0 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPanelPostID", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPanelPostID indicates an expected call of GetPanelPostID. +func (mr *MockStoreMockRecorder) GetPanelPostID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPanelPostID", reflect.TypeOf((*MockStore)(nil).GetPanelPostID), arg0) +} + +// SetPanelPostID mocks base method. +func (m *MockStore) SetPanelPostID(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetPanelPostID", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetPanelPostID indicates an expected call of SetPanelPostID. +func (mr *MockStoreMockRecorder) SetPanelPostID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPanelPostID", reflect.TypeOf((*MockStore)(nil).SetPanelPostID), arg0, arg1) +} diff --git a/server/public/pluginapi/experimental/panel/mocks/mock_setting.go b/server/public/pluginapi/experimental/panel/mocks/mock_setting.go new file mode 100644 index 0000000000..38b87f44ac --- /dev/null +++ b/server/public/pluginapi/experimental/panel/mocks/mock_setting.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel/settings (interfaces: Setting) + +// Package mock_panel is a generated GoMock package. +package mock_panel + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + model "github.com/mattermost/mattermost/server/public/model" +) + +// MockSetting is a mock of Setting interface. +type MockSetting struct { + ctrl *gomock.Controller + recorder *MockSettingMockRecorder +} + +// MockSettingMockRecorder is the mock recorder for MockSetting. +type MockSettingMockRecorder struct { + mock *MockSetting +} + +// NewMockSetting creates a new mock instance. +func NewMockSetting(ctrl *gomock.Controller) *MockSetting { + mock := &MockSetting{ctrl: ctrl} + mock.recorder = &MockSettingMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSetting) EXPECT() *MockSettingMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockSetting) Get(arg0 string) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockSettingMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSetting)(nil).Get), arg0) +} + +// GetDependency mocks base method. +func (m *MockSetting) GetDependency() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDependency") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDependency indicates an expected call of GetDependency. +func (mr *MockSettingMockRecorder) GetDependency() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDependency", reflect.TypeOf((*MockSetting)(nil).GetDependency)) +} + +// GetDescription mocks base method. +func (m *MockSetting) GetDescription() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDescription") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDescription indicates an expected call of GetDescription. +func (mr *MockSettingMockRecorder) GetDescription() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDescription", reflect.TypeOf((*MockSetting)(nil).GetDescription)) +} + +// GetID mocks base method. +func (m *MockSetting) GetID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID. +func (mr *MockSettingMockRecorder) GetID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockSetting)(nil).GetID)) +} + +// GetSlackAttachments mocks base method. +func (m *MockSetting) GetSlackAttachments(arg0, arg1 string, arg2 bool) (*model.SlackAttachment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSlackAttachments", arg0, arg1, arg2) + ret0, _ := ret[0].(*model.SlackAttachment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSlackAttachments indicates an expected call of GetSlackAttachments. +func (mr *MockSettingMockRecorder) GetSlackAttachments(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSlackAttachments", reflect.TypeOf((*MockSetting)(nil).GetSlackAttachments), arg0, arg1, arg2) +} + +// GetTitle mocks base method. +func (m *MockSetting) GetTitle() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTitle") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetTitle indicates an expected call of GetTitle. +func (mr *MockSettingMockRecorder) GetTitle() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTitle", reflect.TypeOf((*MockSetting)(nil).GetTitle)) +} + +// IsDisabled mocks base method. +func (m *MockSetting) IsDisabled(arg0 interface{}) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsDisabled", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsDisabled indicates an expected call of IsDisabled. +func (mr *MockSettingMockRecorder) IsDisabled(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDisabled", reflect.TypeOf((*MockSetting)(nil).IsDisabled), arg0) +} + +// Set mocks base method. +func (m *MockSetting) Set(arg0 string, arg1 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockSettingMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockSetting)(nil).Set), arg0, arg1) +} diff --git a/server/public/pluginapi/experimental/panel/panel.go b/server/public/pluginapi/experimental/panel/panel.go new file mode 100644 index 0000000000..e5e9582576 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/panel.go @@ -0,0 +1,171 @@ +package panel + +import ( + "errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/poster" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/common" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/panel/settings" +) + +type Panel interface { + Set(userID, settingID string, value interface{}) error + Print(userID string) + ToPost(userID string) (*model.Post, error) + Clear(userID string) error + URL() string + GetSettingIDs() []string +} + +type panel struct { + settings map[string]settings.Setting + settingKeys []string + poster poster.Poster + logger logger.Logger + store Store + settingHandler string + pluginURL string +} + +func NewSettingsPanel( + settingList []settings.Setting, + p poster.Poster, + l logger.Logger, + store Store, + settingHandler, + pluginURL string, +) Panel { + settingsMap := make(map[string]settings.Setting) + settingKeys := []string{} + for _, s := range settingList { + settingsMap[s.GetID()] = s + settingKeys = append(settingKeys, s.GetID()) + } + + panel := &panel{ + settings: settingsMap, + settingKeys: settingKeys, + poster: p, + logger: l, + store: store, + settingHandler: settingHandler, + pluginURL: pluginURL, + } + + return panel +} + +func (p *panel) Set(userID, settingID string, value interface{}) error { + s, ok := p.settings[settingID] + if !ok { + return errors.New("cannot find setting " + settingID) + } + + err := s.Set(userID, value) + if err != nil { + return err + } + return nil +} + +func (p *panel) GetSettingIDs() []string { + return p.settingKeys +} + +func (p *panel) URL() string { + return p.settingHandler +} + +func (p *panel) Print(userID string) { + err := p.cleanPreviousSettingsPosts(userID) + if err != nil { + p.logger.Errorf("could not clean previous setting post, " + err.Error()) + } + + sas := []*model.SlackAttachment{} + for _, key := range p.settingKeys { + s := p.settings[key] + sa, loopErr := s.GetSlackAttachments(userID, p.pluginURL+p.settingHandler, p.isSettingDisabled(userID, s)) + if loopErr != nil { + p.logger.Errorf("error creating the slack attachment, err=" + loopErr.Error()) + continue + } + sas = append(sas, sa) + } + postID, err := p.poster.DMWithAttachments(userID, sas...) + if err != nil { + p.logger.Errorf("error creating the message, err=", err.Error()) + return + } + + err = p.store.SetPanelPostID(userID, postID) + if err != nil { + p.logger.Errorf("could not set the post IDs, err=", err.Error()) + } +} + +func (p *panel) ToPost(userID string) (*model.Post, error) { + post := &model.Post{} + + sas := []*model.SlackAttachment{} + for _, key := range p.settingKeys { + s := p.settings[key] + sa, err := s.GetSlackAttachments(userID, p.pluginURL+p.settingHandler, p.isSettingDisabled(userID, s)) + if err != nil { + p.logger.Errorf("error creating the slack attachment for setting %s, err=%s", s.GetID(), err.Error()) + continue + } + sas = append(sas, sa) + } + + model.ParseSlackAttachment(post, sas) + return post, nil +} + +func (p *panel) cleanPreviousSettingsPosts(userID string) error { + postID, err := p.store.GetPanelPostID(userID) + if err == common.ErrNotFound { + return nil + } + + if err != nil { + return err + } + + err = p.poster.DeletePost(postID) + if err != nil { + p.logger.Errorf("could not delete setting post, %s", err) + } + + err = p.store.DeletePanelPostID(userID) + if err != nil { + return err + } + + return nil +} + +func (p *panel) Clear(userID string) error { + return p.cleanPreviousSettingsPosts(userID) +} + +func (p *panel) isSettingDisabled(userID string, s settings.Setting) bool { + dependencyID := s.GetDependency() + if dependencyID == "" { + return false + } + dependency, ok := p.settings[dependencyID] + if !ok { + p.logger.Errorf("settings dependency %s not found", dependencyID) + return false + } + + value, err := dependency.Get(userID) + if err != nil { + p.logger.Errorf("cannot get dependency %s value", dependencyID) + return false + } + return s.IsDisabled(value) +} diff --git a/server/public/pluginapi/experimental/panel/settings/base_setting.go b/server/public/pluginapi/experimental/panel/settings/base_setting.go new file mode 100644 index 0000000000..96d533e578 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/base_setting.go @@ -0,0 +1,28 @@ +package settings + +type baseSetting struct { + title string + description string + id string + dependsOn string +} + +func (s *baseSetting) GetID() string { + return s.id +} + +func (s *baseSetting) GetTitle() string { + return s.title +} + +func (s *baseSetting) GetDescription() string { + return s.description +} + +func (s *baseSetting) GetDependency() string { + return s.dependsOn +} + +func (s *baseSetting) IsDisabled(foreignValue interface{}) bool { + return false +} diff --git a/server/public/pluginapi/experimental/panel/settings/bool_setting.go b/server/public/pluginapi/experimental/panel/settings/bool_setting.go new file mode 100644 index 0000000000..bb99d35287 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/bool_setting.go @@ -0,0 +1,114 @@ +package settings + +import ( + "errors" + "fmt" + + "github.com/mattermost/mattermost/server/public/model" +) + +type boolSetting struct { + baseSetting + store SettingStore +} + +// NewBoolSetting creates a new setting input for boolean values +func NewBoolSetting(id, title, description, dependsOn string, store SettingStore) Setting { + return &boolSetting{ + baseSetting: baseSetting{ + title: title, + description: description, + id: id, + dependsOn: dependsOn, + }, + store: store, + } +} + +func (s *boolSetting) Set(userID string, value interface{}) error { + boolValue := false + if value == TrueString { + boolValue = true + } + + err := s.store.SetSetting(userID, s.id, boolValue) + if err != nil { + return err + } + + return nil +} + +func (s *boolSetting) Get(userID string) (interface{}, error) { + value, err := s.store.GetSetting(userID, s.id) + if err != nil { + return "", err + } + boolValue, ok := value.(bool) + if !ok { + return "", errors.New("current value is not a bool") + } + + stringValue := FalseString + if boolValue { + stringValue = TrueString + } + + return stringValue, nil +} + +func (s *boolSetting) GetSlackAttachments(userID, settingHandler string, disabled bool) (*model.SlackAttachment, error) { + title := fmt.Sprintf("Setting: %s", s.title) + currentValueMessage := DisabledString + + actions := []*model.PostAction{} + if !disabled { + currentValue, err := s.Get(userID) + if err != nil { + return nil, err + } + + currentTextValue := "No" + if currentValue == TrueString { + currentTextValue = "Yes" + } + currentValueMessage = fmt.Sprintf("Current value: %s", currentTextValue) + + actionTrue := model.PostAction{ + Name: "Yes", + Integration: &model.PostActionIntegration{ + URL: settingHandler, + Context: map[string]interface{}{ + ContextIDKey: s.id, + ContextButtonValueKey: TrueString, + }, + }, + } + + actionFalse := model.PostAction{ + Name: "No", + Integration: &model.PostActionIntegration{ + URL: settingHandler, + Context: map[string]interface{}{ + ContextIDKey: s.id, + ContextButtonValueKey: FalseString, + }, + }, + } + actions = []*model.PostAction{&actionTrue, &actionFalse} + } + + text := fmt.Sprintf("%s\n%s", s.description, currentValueMessage) + sa := model.SlackAttachment{ + Title: title, + Text: text, + Fallback: fmt.Sprintf("%s: %s", title, text), + Actions: actions, + } + + return &sa, nil +} + +func (s *boolSetting) IsDisabled(foreignValue interface{}) bool { + return foreignValue == FalseString +} diff --git a/server/public/pluginapi/experimental/panel/settings/empty_setting.go b/server/public/pluginapi/experimental/panel/settings/empty_setting.go new file mode 100644 index 0000000000..61303a1960 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/empty_setting.go @@ -0,0 +1,41 @@ +package settings + +import ( + "fmt" + + "github.com/mattermost/mattermost/server/public/model" +) + +type emptySetting struct { + baseSetting +} + +// NewEmptySetting creates a new panel value with no setting attached +func NewEmptySetting(id, title, description string) Setting { + return &emptySetting{ + baseSetting: baseSetting{ + id: id, + title: title, + description: description, + }, + } +} + +func (s *emptySetting) GetSlackAttachments(userID, settingHandler string, disabled bool) (*model.SlackAttachment, error) { + title := fmt.Sprintf("Setting: %s", s.title) + sa := model.SlackAttachment{ + Title: title, + Text: s.description, + Fallback: fmt.Sprintf("%s: %s", title, s.description), + } + + return &sa, nil +} + +func (s *emptySetting) Get(userID string) (interface{}, error) { + return nil, nil +} + +func (s *emptySetting) Set(userID string, value interface{}) error { + return nil +} diff --git a/server/public/pluginapi/experimental/panel/settings/option_setting.go b/server/public/pluginapi/experimental/panel/settings/option_setting.go new file mode 100644 index 0000000000..5d90878641 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/option_setting.go @@ -0,0 +1,91 @@ +package settings + +import ( + "errors" + "fmt" + + "github.com/mattermost/mattermost/server/public/model" +) + +type optionSetting struct { + baseSetting + options []string + store SettingStore +} + +// NewOptionSetting creates a new setting input to select from a dropdown +func NewOptionSetting(id, title, description, dependsOn string, options []string, store SettingStore) Setting { + return &optionSetting{ + baseSetting: baseSetting{ + title: title, + description: description, + id: id, + dependsOn: dependsOn, + }, + options: options, + store: store, + } +} + +func (s *optionSetting) Set(userID string, value interface{}) error { + err := s.store.SetSetting(userID, s.id, value) + if err != nil { + return err + } + + return nil +} + +func (s *optionSetting) Get(userID string) (interface{}, error) { + value, err := s.store.GetSetting(userID, s.id) + if err != nil { + return "", err + } + valueString, ok := value.(string) + if !ok { + return "", errors.New("current value is not a string") + } + + return valueString, nil +} + +func (s *optionSetting) GetSlackAttachments(userID, settingHandler string, disabled bool) (*model.SlackAttachment, error) { + title := fmt.Sprintf("Setting: %s", s.title) + currentValueMessage := DisabledString + + actions := []*model.PostAction{} + if !disabled { + currentTextValue, err := s.Get(userID) + if err != nil { + return nil, err + } + currentValueMessage = fmt.Sprintf("Current value: %s", currentTextValue) + + actionOptions := model.PostAction{ + Name: "Select an option:", + Integration: &model.PostActionIntegration{ + URL: settingHandler + "?" + s.id + "=true", + Context: map[string]interface{}{ + ContextIDKey: s.id, + }, + }, + Type: "select", + Options: stringsToOptions(s.options), + } + + actions = []*model.PostAction{&actionOptions} + } + + text := fmt.Sprintf("%s\n%s", s.description, currentValueMessage) + sa := model.SlackAttachment{ + Title: title, + Text: text, + Fallback: fmt.Sprintf("%s: %s", title, text), + Actions: actions, + } + return &sa, nil +} + +func (s *optionSetting) IsDisabled(foreignValue interface{}) bool { + return foreignValue == FalseString +} diff --git a/server/public/pluginapi/experimental/panel/settings/read_only_setting.go b/server/public/pluginapi/experimental/panel/settings/read_only_setting.go new file mode 100644 index 0000000000..3604b4e5d3 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/read_only_setting.go @@ -0,0 +1,69 @@ +package settings + +import ( + "errors" + "fmt" + + "github.com/mattermost/mattermost/server/public/model" +) + +type readOnlySetting struct { + baseSetting + store SettingStore +} + +// NewReadOnlySetting creates a new panel value that only read from the setting +func NewReadOnlySetting(id, title, description, dependsOn string, store SettingStore) Setting { + return &readOnlySetting{ + baseSetting: baseSetting{ + title: title, + description: description, + id: id, + dependsOn: dependsOn, + }, + store: store, + } +} + +func (s *readOnlySetting) Get(userID string) (interface{}, error) { + value, err := s.store.GetSetting(userID, s.id) + if err != nil { + return "", err + } + stringValue, ok := value.(string) + if !ok { + return "", errors.New("current value is not a string") + } + + return stringValue, nil +} + +func (s *readOnlySetting) Set(userID string, value interface{}) error { + return nil +} + +func (s *readOnlySetting) GetSlackAttachments(userID, settingHandler string, disabled bool) (*model.SlackAttachment, error) { + title := fmt.Sprintf("Setting: %s", s.title) + currentValueMessage := DisabledString + + if !disabled { + currentValue, err := s.Get(userID) + if err != nil { + return nil, err + } + currentValueMessage = fmt.Sprintf("Current value: %s", currentValue) + } + + text := fmt.Sprintf("%s\n%s", s.description, currentValueMessage) + sa := model.SlackAttachment{ + Title: title, + Text: text, + Fallback: fmt.Sprintf("%s: %s", title, text), + } + + return &sa, nil +} + +func (s *readOnlySetting) IsDisabled(foreignValue interface{}) bool { + return foreignValue == FalseString +} diff --git a/server/public/pluginapi/experimental/panel/settings/setting.go b/server/public/pluginapi/experimental/panel/settings/setting.go new file mode 100644 index 0000000000..672ebdfd8f --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/setting.go @@ -0,0 +1,33 @@ +package settings + +import ( + "github.com/mattermost/mattermost/server/public/model" +) + +const ( + // ContextIDKey defines the key used in the context to store the ID + ContextIDKey = "setting_id" + // ContextButtonValueKey defines the key used in the context to store a button value + ContextButtonValueKey = "button_value" + // ContextOptionValueKey defines the key used in the context to store a selected option value + ContextOptionValueKey = "selected_option" + + // DisabledString defines the string used to show that a setting is disabled + DisabledString = "Disabled" + // TrueString codify the boolean true into a string + TrueString = "true" + // FalseString codify the boolean false into a string + FalseString = "false" +) + +// Setting defines the behavior of each element a the panel +type Setting interface { + Set(userID string, value interface{}) error + Get(userID string) (interface{}, error) + GetID() string + GetDependency() string + IsDisabled(foreignValue interface{}) bool + GetTitle() string + GetDescription() string + GetSlackAttachments(userID, settingHandler string, disabled bool) (*model.SlackAttachment, error) +} diff --git a/server/public/pluginapi/experimental/panel/settings/store.go b/server/public/pluginapi/experimental/panel/settings/store.go new file mode 100644 index 0000000000..45afce6c74 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/store.go @@ -0,0 +1,7 @@ +package settings + +// SettingStore defines the behavior needed to set and get settings +type SettingStore interface { + SetSetting(userID, settingID string, value interface{}) error + GetSetting(userID, settingID string) (interface{}, error) +} diff --git a/server/public/pluginapi/experimental/panel/settings/utils.go b/server/public/pluginapi/experimental/panel/settings/utils.go new file mode 100644 index 0000000000..3a1cfe0951 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/settings/utils.go @@ -0,0 +1,16 @@ +package settings + +import ( + "github.com/mattermost/mattermost/server/public/model" +) + +func stringsToOptions(in []string) []*model.PostActionOptions { + out := make([]*model.PostActionOptions, len(in)) + for i, o := range in { + out[i] = &model.PostActionOptions{ + Text: o, + Value: o, + } + } + return out +} diff --git a/server/public/pluginapi/experimental/panel/store.go b/server/public/pluginapi/experimental/panel/store.go new file mode 100644 index 0000000000..dabc1892c8 --- /dev/null +++ b/server/public/pluginapi/experimental/panel/store.go @@ -0,0 +1,53 @@ +package panel + +import ( + "errors" + + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +type Store interface { + SetPanelPostID(userID string, postID string) error + GetPanelPostID(userID string) (string, error) + DeletePanelPostID(userID string) error +} + +type panelStore struct { + kv *pluginapi.KVService + keyPrefix string +} + +func NewPanelStore(kv *pluginapi.KVService, keyPrefix string) Store { + return &panelStore{ + kv: kv, + keyPrefix: keyPrefix, + } +} + +func (ps *panelStore) SetPanelPostID(userID, postID string) error { + ok, err := ps.kv.Set(ps.getKey(userID), postID) + if err != nil { + return err + } + if !ok { + return errors.New("value not set without errors") + } + return nil +} + +func (ps *panelStore) GetPanelPostID(userID string) (string, error) { + var postID string + err := ps.kv.Get(ps.getKey(userID), &postID) + if err != nil { + return "", err + } + return postID, nil +} + +func (ps *panelStore) DeletePanelPostID(userID string) error { + return ps.kv.Delete(ps.getKey(userID)) +} + +func (ps *panelStore) getKey(userID string) string { + return ps.keyPrefix + "-" + userID +} diff --git a/server/public/pluginapi/experimental/telemetry/doc.go b/server/public/pluginapi/experimental/telemetry/doc.go new file mode 100644 index 0000000000..5c39b63ae3 --- /dev/null +++ b/server/public/pluginapi/experimental/telemetry/doc.go @@ -0,0 +1,76 @@ +// Package telemetry allows you to add telemetry to your plugins. +// For Rudder, you can set the data plane URL and the write key on build time, +// to allow having different keys for production and development. +// If you are working on a Mattermost project, the data plane URL is already set. +// In order to default to the development key we have to set an environment variable during build time. +// Copy the following lines in build/custom.mk to setup that variable. +// +// ifndef MM_RUDDER_WRITE_KEY +// MM_RUDDER_WRITE_KEY = 1d5bMvdrfWClLxgK1FvV3s4U1tg +// endif +// +// To use this environment variable to set the key in the plugin, +// you have to add this line after the previous ones. +// +// LDFLAGS += -X "github.com/mattermost/mattermost/server/public/pluginapi/experimental/telemetry.rudderWriteKey=$(MM_RUDDER_WRITE_KEY)" +// +// MM_RUDDER_WRITE_KEY environment variable must be set also during CI +// to the production write key ("1dP7Oi78p0PK1brYLsfslgnbD1I"). +// If you want to use your own data plane URL, add also this line and +// make sure the MM_RUDDER_DATAPLANE_URL environment variable is set. +// +// LDFLAGS += -X "github.com/mattermost/mattermost/server/public/pluginapi/experimental/telemetry.rudderDataPlaneURL=$(MM_RUDDER_DATAPLANE_URL)" +// +// In order to use telemetry you should: +// +// 1. Add the new fields to the plugin +// +// type Plugin struct { +// plugin.MattermostPlugin +// ... +// telemetryClient telemetry.Client +// tracker telemetry.Tracker +// } +// +// 2. Start the telemetry client and tracker on plugin activate +// +// func (p *Plugin) OnActivate() error { +// p.telemetryClient, err = telemetry.NewRudderClient() +// if err != nil { +// p.API.LogWarn("telemetry client not started", "error", err.Error()) +// } +// ... +// p.tracker = telemetry.NewTracker( +// p.telemetryClient, +// p.API.GetDiagnosticId(), +// p.API.GetServerVersion(), +// Manifest.Id, +// Manifest.Version, +// "plugin_short_namame", +// telemetry.NewTrackerConfig(p.API.GetConfig()), +// logger.New(p.API) +// ) +// } +// +// 3. Trigger tracker changes when configuration changes +// +// func (p *Plugin) OnConfigurationChange() error { +// ... +// if p.tracker != nil { +// p.tracker.ReloadConfig(telemetry.NewTrackerConfig(p.API.GetConfig())) +// } +// return nil +// } +// +// 4. Close the client on plugin deactivate +// +// func (p *Plugin) OnDeactivate() error { +// if p.telemetryClient != nil { +// err := p.telemetryClient.Close() +// if err != nil { +// p.API.LogWarn("OnDeactivate: failed to close telemetryClient", "error", err.Error()) +// } +// } +// return nil +// } +package telemetry diff --git a/server/public/pluginapi/experimental/telemetry/rudder.go b/server/public/pluginapi/experimental/telemetry/rudder.go new file mode 100644 index 0000000000..76306977c6 --- /dev/null +++ b/server/public/pluginapi/experimental/telemetry/rudder.go @@ -0,0 +1,49 @@ +package telemetry + +import ( + rudder "github.com/rudderlabs/analytics-go" +) + +// rudderDataPlaneURL is set to the common Data Plane URL for all Mattermost Projects. +// It can be set during build time. More info in the package documentation. +var rudderDataPlaneURL = "https://pdat.matterlytics.com" + +// rudderWriteKey is set during build time. More info in the package documentation. +var rudderWriteKey string + +// NewRudderClient creates a new telemetry client with Rudder using the default configuration. +func NewRudderClient() (Client, error) { + return NewRudderClientWithCredentials(rudderWriteKey, rudderDataPlaneURL) +} + +// NewRudderClientWithCredentials lets you create a Rudder client with your own credentials. +func NewRudderClientWithCredentials(writeKey, dataPlaneURL string) (Client, error) { + client, err := rudder.NewWithConfig(writeKey, dataPlaneURL, rudder.Config{}) + if err != nil { + return nil, err + } + + return &rudderWrapper{client: client}, nil +} + +type rudderWrapper struct { + client rudder.Client +} + +func (r *rudderWrapper) Enqueue(t Track) error { + var context *rudder.Context + if t.InstallationID != "" { + context = &rudder.Context{Traits: map[string]any{"installationId": t.InstallationID}} + } + + return r.client.Enqueue(rudder.Track{ + UserId: t.UserID, + Event: t.Event, + Context: context, + Properties: t.Properties, + }) +} + +func (r *rudderWrapper) Close() error { + return r.client.Close() +} diff --git a/server/public/pluginapi/experimental/telemetry/tracker.go b/server/public/pluginapi/experimental/telemetry/tracker.go new file mode 100644 index 0000000000..ccb16c5326 --- /dev/null +++ b/server/public/pluginapi/experimental/telemetry/tracker.go @@ -0,0 +1,179 @@ +package telemetry + +import ( + "os" + "sync" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/pluginapi/experimental/bot/logger" +) + +type TrackerConfig struct { + EnabledTracking bool + EnabledLogging bool +} + +// NewTrackerConfig returns a new trackerConfig from the current values of the model.Config. +func NewTrackerConfig(config *model.Config) TrackerConfig { + var enabledTracking, enabledLogging bool + if config == nil { + return TrackerConfig{} + } + + if enableDiagnostics := config.LogSettings.EnableDiagnostics; enableDiagnostics != nil { + enabledTracking = *enableDiagnostics + } + + if enableDeveloper := config.ServiceSettings.EnableDeveloper; enableDeveloper != nil { + enabledLogging = *enableDeveloper + } + + return TrackerConfig{ + EnabledTracking: enabledTracking, + EnabledLogging: enabledLogging, + } +} + +// Tracker defines a telemetry tracker +type Tracker interface { + // TrackEvent registers an event through the configured telemetry client + TrackEvent(event string, properties map[string]interface{}) error + // TrackUserEvent registers an event through the configured telemetry client associated to a user + TrackUserEvent(event string, userID string, properties map[string]interface{}) error + // Reload Config re-evaluates tracker config to determine if tracking behavior should change + ReloadConfig(config TrackerConfig) +} + +// Client defines a telemetry client +type Client interface { + // Enqueue adds a tracker event (Track) to be registered + Enqueue(t Track) error + // Close closes the client connection, flushing any event left on the queue + Close() error +} + +// Track defines an event ready for the client to process +type Track struct { + UserID string + Event string + Properties map[string]interface{} + InstallationID string +} + +type tracker struct { + client Client + diagnosticID string + serverVersion string + pluginID string + pluginVersion string + telemetryShortName string + configLock sync.RWMutex + config TrackerConfig + logger logger.Logger +} + +// NewTracker creates a default Tracker +// - c Client: A telemetry client. If nil, the tracker will not track any event. +// - diagnosticID: Server unique ID used for telemetry. +// - severVersion: Mattermost server version. +// - pluginID: The plugin ID. +// - pluginVersion: The plugin version. +// - telemetryShortName: Short name for the plugin to use in telemetry. Used to avoid dot separated names like `com.company.pluginName`. +// If a empty string is provided, it will use the pluginID. +// - config: Whether the system has enabled sending telemetry data. If false, the tracker will not track any event. +// - l Logger: A logger to debug event tracking and some important changes (it won't log if nil is passed as logger). +func NewTracker( + c Client, + diagnosticID, + serverVersion, + pluginID, + pluginVersion, + telemetryShortName string, + config TrackerConfig, + l logger.Logger, +) Tracker { + if telemetryShortName == "" { + telemetryShortName = pluginID + } + return &tracker{ + telemetryShortName: telemetryShortName, + client: c, + diagnosticID: diagnosticID, + serverVersion: serverVersion, + pluginID: pluginID, + pluginVersion: pluginVersion, + logger: l, + config: config, + } +} + +func (t *tracker) ReloadConfig(config TrackerConfig) { + t.configLock.Lock() + defer t.configLock.Unlock() + + if config.EnabledTracking != t.config.EnabledTracking { + if config.EnabledTracking { + t.debugf("Enabling plugin telemetry") + } else { + t.debugf("Disabling plugin telemetry") + } + } + + t.config.EnabledTracking = config.EnabledTracking + t.config.EnabledLogging = config.EnabledLogging +} + +// Note that config lock is handled by the caller. +func (t *tracker) debugf(message string, args ...interface{}) { + if t.logger == nil || !t.config.EnabledLogging { + return + } + t.logger.Debugf(message, args...) +} + +func (t *tracker) TrackEvent(event string, properties map[string]interface{}) error { + t.configLock.RLock() + defer t.configLock.RUnlock() + + event = t.telemetryShortName + "_" + event + if !t.config.EnabledTracking || t.client == nil { + t.debugf("Plugin telemetry event `%s` tracked, but not sent due to configuration", event) + return nil + } + + if properties == nil { + properties = map[string]interface{}{} + } + properties["PluginID"] = t.pluginID + properties["PluginVersion"] = t.pluginVersion + properties["ServerVersion"] = t.serverVersion + + // if we are part of a cloud installation, add it's ID to the tracked event's context. + installationID := os.Getenv("MM_CLOUD_INSTALLATION_ID") + + err := t.client.Enqueue(Track{ + // We consider the server the "user" on the telemetry system. Any reference to the actual user is passed by properties. + UserID: t.diagnosticID, + Event: event, + Properties: properties, + InstallationID: installationID, + }) + + if err != nil { + return errors.Wrap(err, "cannot enqueue the track") + } + t.debugf("Tracked plugin telemetry event `%s`", event) + + return nil +} + +func (t *tracker) TrackUserEvent(event, userID string, properties map[string]interface{}) error { + if properties == nil { + properties = map[string]interface{}{} + } + + properties["UserActualID"] = userID + return t.TrackEvent(event, properties) +} diff --git a/server/public/pluginapi/file.go b/server/public/pluginapi/file.go new file mode 100644 index 0000000000..d295eb414c --- /dev/null +++ b/server/public/pluginapi/file.go @@ -0,0 +1,85 @@ +package pluginapi + +import ( + "bytes" + "io" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// FileService exposes methods to manipulate files, most often as post attachments. +type FileService struct { + api plugin.API +} + +// Get gets content of a file by id. +// +// Minimum server version: 5.8 +func (f *FileService) Get(id string) (io.Reader, error) { + contentBytes, appErr := f.api.GetFile(id) + if appErr != nil { + return nil, normalizeAppErr(appErr) + } + + return bytes.NewReader(contentBytes), nil +} + +// GetByPath reads a file by its path on the dist. +// +// Minimum server version: 5.3 +func (f *FileService) GetByPath(path string) (io.Reader, error) { + contentBytes, appErr := f.api.ReadFile(path) + if appErr != nil { + return nil, normalizeAppErr(appErr) + } + + return bytes.NewReader(contentBytes), nil +} + +// GetInfo gets a file's info by id. +// +// Minimum server version: 5.3 +func (f *FileService) GetInfo(id string) (*model.FileInfo, error) { + info, appErr := f.api.GetFileInfo(id) + + return info, normalizeAppErr(appErr) +} + +// GetLink gets the public link of a file by id. +// +// Minimum server version: 5.6 +func (f *FileService) GetLink(id string) (string, error) { + link, appErr := f.api.GetFileLink(id) + + return link, normalizeAppErr(appErr) +} + +// Upload uploads a file to a channel to be later attached to a post. +// +// Minimum server version: 5.6 +func (f *FileService) Upload(content io.Reader, fileName, channelID string) (*model.FileInfo, error) { + contentBytes, err := io.ReadAll(content) + if err != nil { + return nil, err + } + + info, appErr := f.api.UploadFile(contentBytes, channelID, fileName) + + return info, normalizeAppErr(appErr) +} + +// CopyInfos duplicates the FileInfo objects referenced by the given file ids, recording +// the given user id as the new creator and returning the new set of file ids. +// +// The duplicate FileInfo objects are not initially linked to a post, but may now be passed +// on creation of a post. +// Use this API to duplicate a post and its file attachments without actually duplicating +// the uploaded files. +// +// Minimum server version: 5.2 +func (f *FileService) CopyInfos(ids []string, userID string) ([]string, error) { + newIDs, appErr := f.api.CopyFileInfos(userID, ids) + + return newIDs, normalizeAppErr(appErr) +} diff --git a/server/public/pluginapi/file_test.go b/server/public/pluginapi/file_test.go new file mode 100644 index 0000000000..fa15817006 --- /dev/null +++ b/server/public/pluginapi/file_test.go @@ -0,0 +1,185 @@ +package pluginapi_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestGetFile(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetFile", "1").Return([]byte{2}, nil) + + content, err := client.File.Get("1") + require.NoError(t, err) + contentBytes, err := io.ReadAll(content) + require.NoError(t, err) + require.Equal(t, []byte{2}, contentBytes) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetFile", "1").Return(nil, appErr) + + content, err := client.File.Get("1") + require.Equal(t, appErr, err) + require.Zero(t, content) + }) +} + +func TestGetFileByPath(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("ReadFile", "1").Return([]byte{2}, nil) + + content, err := client.File.GetByPath("1") + require.NoError(t, err) + contentBytes, err := io.ReadAll(content) + require.NoError(t, err) + require.Equal(t, []byte{2}, contentBytes) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("ReadFile", "1").Return(nil, appErr) + + content, err := client.File.GetByPath("1") + require.Equal(t, appErr, err) + require.Zero(t, content) + }) +} + +func TestGetFileInfo(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetFileInfo", "1").Return(&model.FileInfo{Id: "2"}, nil) + + info, err := client.File.GetInfo("1") + require.NoError(t, err) + require.Equal(t, "2", info.Id) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetFileInfo", "1").Return(nil, appErr) + + info, err := client.File.GetInfo("1") + require.Equal(t, appErr, err) + require.Zero(t, info) + }) +} + +func TestGetFileLink(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetFileLink", "1").Return("2", nil) + + link, err := client.File.GetLink("1") + require.NoError(t, err) + require.Equal(t, "2", link) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("GetFileLink", "1").Return("", appErr) + + link, err := client.File.GetLink("1") + require.Equal(t, appErr, err) + require.Zero(t, link) + }) +} + +func TestUploadFile(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("UploadFile", []byte{1}, "3", "2").Return(&model.FileInfo{Id: "4"}, nil) + + info, err := client.File.Upload(bytes.NewReader([]byte{1}), "2", "3") + require.NoError(t, err) + require.Equal(t, "4", info.Id) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("UploadFile", []byte{1}, "3", "2").Return(nil, appErr) + + info, err := client.File.Upload(bytes.NewReader([]byte{1}), "2", "3") + require.Equal(t, appErr, err) + require.Zero(t, info) + }) +} + +func TestCopyFileInfos(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("CopyFileInfos", "3", []string{"1", "2"}).Return([]string{"4", "5"}, nil) + + newIDs, err := client.File.CopyInfos([]string{"1", "2"}, "3") + require.NoError(t, err) + require.Equal(t, []string{"4", "5"}, newIDs) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := newAppError() + + api.On("CopyFileInfos", "3", []string{"1", "2"}).Return(nil, appErr) + + newIDs, err := client.File.CopyInfos([]string{"1", "2"}, "3") + require.Equal(t, appErr, err) + require.Zero(t, newIDs) + }) +} diff --git a/server/public/pluginapi/frontend.go b/server/public/pluginapi/frontend.go new file mode 100644 index 0000000000..dc94059839 --- /dev/null +++ b/server/public/pluginapi/frontend.go @@ -0,0 +1,30 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// FrontendService exposes methods to interact with the frontend. +type FrontendService struct { + api plugin.API +} + +// OpenInteractiveDialog will open an interactive dialog on a user's client that +// generated the trigger ID. Used with interactive message buttons, menus +// and slash commands. +// +// Minimum server version: 5.6 +func (f *FrontendService) OpenInteractiveDialog(dialog model.OpenDialogRequest) error { + return normalizeAppErr(f.api.OpenInteractiveDialog(dialog)) +} + +// PublishWebSocketEvent sends an event to WebSocket connections. +// event is the type and will be prepended with "custom__". +// payload is the data sent with the event. Interface values must be primitive Go types or mattermost-server/model types. +// broadcast determines to which users to send the event. +// +// Minimum server version: 5.2 +func (f *FrontendService) PublishWebSocketEvent(event string, payload map[string]interface{}, broadcast *model.WebsocketBroadcast) { + f.api.PublishWebSocketEvent(event, payload, broadcast) +} diff --git a/server/public/pluginapi/group.go b/server/public/pluginapi/group.go new file mode 100644 index 0000000000..0aa9f01ca3 --- /dev/null +++ b/server/public/pluginapi/group.go @@ -0,0 +1,57 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// GroupService exposes methods to manipulate groups. +type GroupService struct { + api plugin.API +} + +// Get gets a group by ID. +// +// Minimum server version: 5.18 +func (g *GroupService) Get(groupID string) (*model.Group, error) { + group, appErr := g.api.GetGroup(groupID) + + return group, normalizeAppErr(appErr) +} + +// GetByName gets a group by name. +// +// Minimum server version: 5.18 +func (g *GroupService) GetByName(name string) (*model.Group, error) { + group, appErr := g.api.GetGroupByName(name) + + return group, normalizeAppErr(appErr) +} + +// GetMemberUsers gets a page of users from the given group. +// +// Minimum server version: 5.35 +func (g *GroupService) GetMemberUsers(groupID string, page, perPage int) ([]*model.User, error) { + users, appErr := g.api.GetGroupMemberUsers(groupID, page, perPage) + + return users, normalizeAppErr(appErr) +} + +// GetBySource gets a list of all groups for the given source. +// +// @tag Group +// Minimum server version: 5.35 +func (g *GroupService) GetBySource(groupSource model.GroupSource) ([]*model.Group, error) { + groups, appErr := g.api.GetGroupsBySource(groupSource) + + return groups, normalizeAppErr(appErr) +} + +// ListForUser gets the groups a user is in. +// +// Minimum server version: 5.18 +func (g *GroupService) ListForUser(userID string) ([]*model.Group, error) { + groups, appErr := g.api.GetGroupsForUser(userID) + + return groups, normalizeAppErr(appErr) +} diff --git a/server/public/pluginapi/i18n/doc.go b/server/public/pluginapi/i18n/doc.go new file mode 100644 index 0000000000..aede6b94cc --- /dev/null +++ b/server/public/pluginapi/i18n/doc.go @@ -0,0 +1,2 @@ +// package i18n provides methods to read translations files and localize strings. +package i18n diff --git a/server/public/pluginapi/i18n/i18n.go b/server/public/pluginapi/i18n/i18n.go new file mode 100644 index 0000000000..26447c0adb --- /dev/null +++ b/server/public/pluginapi/i18n/i18n.go @@ -0,0 +1,136 @@ +package i18n + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + + "github.com/nicksnyder/go-i18n/v2/i18n" + "github.com/pkg/errors" + "golang.org/x/text/language" + + "github.com/mattermost/mattermost/server/public/model" +) + +// PluginAPI is the plugin API interface required to manage translations. +type PluginAPI interface { + GetBundlePath() (string, error) + GetConfig() *model.Config + GetUser(userID string) (*model.User, *model.AppError) + LogWarn(msg string, keyValuePairs ...interface{}) +} + +// Message is a string that can be localized. +// +// See https://pkg.go.dev/github.com/nicksnyder/go-i18n/v2/i18n?tab=doc#Message for more details. +type Message = i18n.Message + +// LocalizeConfig configures a call to the Localize method on Localizer. +// +// See https://pkg.go.dev/github.com/nicksnyder/go-i18n/v2/i18n?tab=doc#LocalizeConfig for more details. +type LocalizeConfig = i18n.LocalizeConfig + +// Localizer provides Localize and MustLocalize methods that return localized messages. +// +// See https://pkg.go.dev/github.com/nicksnyder/go-i18n/v2/i18n?tab=doc#Localizer for more details. +type Localizer = i18n.Localizer + +// Bundle stores a set of messages and pluralization rules. +// Most plugins only need a single bundle +// that is initialized on activation. +// It is not goroutine safe to modify the bundle while Localizers +// are reading from it. +type Bundle struct { + *i18n.Bundle + api PluginAPI +} + +// InitBundle loads all localization files from a given path into a bundle and return this. +// path is a relative path in the plugin bundle, e.g. assets/i18n. +// Every file except the ones named active.*.json. +// The default language is English. +func InitBundle(api PluginAPI, path string) (*Bundle, error) { + bundle := &Bundle{ + Bundle: i18n.NewBundle(language.English), + api: api, + } + bundle.RegisterUnmarshalFunc("json", json.Unmarshal) + + bundlePath, err := api.GetBundlePath() + if err != nil { + return nil, errors.Wrap(err, "failed to get bundle path") + } + + i18nDir := filepath.Join(bundlePath, path) + + files, err := os.ReadDir(i18nDir) + if err != nil { + return nil, errors.Wrap(err, "failed to open i18n directory") + } + + for _, file := range files { + if !strings.HasPrefix(file.Name(), "active.") { + continue + } + + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + if file.Name() == "active.en.json" { + continue + } + + _, err = bundle.LoadMessageFile(filepath.Join(i18nDir, file.Name())) + if err != nil { + return nil, errors.Wrapf(err, "failed to load message file %s", file.Name()) + } + } + + return bundle, nil +} + +// GetUserLocalizer returns a localizer that localizes in the users locale. +func (b *Bundle) GetUserLocalizer(userID string) *i18n.Localizer { + user, err := b.api.GetUser(userID) + if err != nil { + b.api.LogWarn("Failed get user's locale", "error", err.Error()) + return b.GetServerLocalizer() + } + + return i18n.NewLocalizer(b.Bundle, user.Locale) +} + +// GetServerLocalizer returns a localizer that localizes in the default server locale. +// +// This is useful for situations where a messages is shown to every user, +// independent of the users locale. +func (b *Bundle) GetServerLocalizer() *i18n.Localizer { + local := *b.api.GetConfig().LocalizationSettings.DefaultServerLocale + + return i18n.NewLocalizer(b.Bundle, local) +} + +// LocalizeDefaultMessage localizer the provided message. +// An empty string is returned when the localization fails. +func (b *Bundle) LocalizeDefaultMessage(l *Localizer, m *Message) string { + s, err := l.LocalizeMessage(m) + if err != nil { + b.api.LogWarn("Failed to localize message", "message ID", m.ID, "error", err.Error()) + return "" + } + + return s +} + +// LocalizeWithConfig localizer the provided localize config. +// An empty string is returned when the localization fails. +func (b *Bundle) LocalizeWithConfig(l *Localizer, lc *LocalizeConfig) string { + s, err := l.Localize(lc) + if err != nil { + b.api.LogWarn("Failed to localize with config", "error", err.Error()) + return "" + } + return s +} diff --git a/server/public/pluginapi/i18n/i18n_test.go b/server/public/pluginapi/i18n/i18n_test.go new file mode 100644 index 0000000000..aff327d2ea --- /dev/null +++ b/server/public/pluginapi/i18n/i18n_test.go @@ -0,0 +1,296 @@ +package i18n_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi/i18n" +) + +//nolint:govet +func ExampleInitBundle() { + type Plugin struct { + plugin.MattermostPlugin + + b *i18n.Bundle + } + + p := Plugin{} + b, err := i18n.InitBundle(p.API, filepath.Join("assets", "i18n")) + if err != nil { + panic(err) + } + + p.b = b +} + +func TestInitBundle(t *testing.T) { + t.Run("fine", func(t *testing.T) { + dir, err := os.MkdirTemp("", "") + require.NoError(t, err) + + defer os.RemoveAll(dir) + + // Create assets/i18n dir + i18nDir := filepath.Join(dir, "assets", "i18n") + err = os.MkdirAll(i18nDir, 0o700) + require.NoError(t, err) + + file := filepath.Join(i18nDir, "active.de.json") + content := []byte("{}") + err = os.WriteFile(file, content, 0o600) + require.NoError(t, err) + + // Add en translation file. + // InitBundle should ignore it. + file = filepath.Join(i18nDir, "active.en.json") + content = []byte("") + err = os.WriteFile(file, content, 0o600) + require.NoError(t, err) + + // Add json junk file + file = filepath.Join(i18nDir, "foo.json") + content = []byte("") + err = os.WriteFile(file, content, 0o600) + require.NoError(t, err) + + // Add active. junk file + file = filepath.Join(i18nDir, "active.foo") + content = []byte("") + err = os.WriteFile(file, content, 0o600) + require.NoError(t, err) + + api := &plugintest.API{} + api.On("GetBundlePath").Return(dir, nil) + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, "assets/i18n") + assert.NoError(t, err) + assert.NotNil(t, b) + + assert.ElementsMatch(t, []language.Tag{language.English, language.German}, b.LanguageTags()) + }) + + t.Run("fine", func(t *testing.T) { + dir, err := os.MkdirTemp("", "") + require.NoError(t, err) + + defer os.RemoveAll(dir) + + // Create assets/i18n dir + i18nDir := filepath.Join(dir, "assets", "i18n") + err = os.MkdirAll(i18nDir, 0o700) + require.NoError(t, err) + + file := filepath.Join(i18nDir, "active.de.json") + content := []byte("{}") + err = os.WriteFile(file, content, 0o600) + require.NoError(t, err) + + // Add translation file with invalid content + file = filepath.Join(i18nDir, "active.es.json") + content = []byte("foo bar") + err = os.WriteFile(file, content, 0o600) + require.NoError(t, err) + + api := &plugintest.API{} + api.On("GetBundlePath").Return(dir, nil) + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, "assets/i18n") + assert.Error(t, err) + assert.Nil(t, b) + }) +} + +func TestLocalizeDefaultMessage(t *testing.T) { + t.Run("fine", func(t *testing.T) { + api := &plugintest.API{} + defaultServerLocale := "en" + api.On("GetConfig").Return(&model.Config{ + LocalizationSettings: model.LocalizationSettings{ + DefaultServerLocale: &defaultServerLocale, + }, + }) + api.On("GetBundlePath").Return(".", nil) + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetServerLocalizer() + m := &i18n.Message{ + Other: "test message", + } + + assert.Equal(t, m.Other, b.LocalizeDefaultMessage(l, m)) + }) + + t.Run("empty message", func(t *testing.T) { + api := &plugintest.API{} + defaultServerLocale := "en" + api.On("GetConfig").Return(&model.Config{ + LocalizationSettings: model.LocalizationSettings{ + DefaultServerLocale: &defaultServerLocale, + }, + }) + api.On("GetBundlePath").Return(".", nil) + api.On("LogWarn", mock.AnythingOfType("string"), + mock.AnythingOfType("string"), mock.AnythingOfType("string"), + mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return() + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetServerLocalizer() + m := &i18n.Message{} + + assert.Equal(t, "", b.LocalizeDefaultMessage(l, m)) + }) +} + +func TestLocalizeWithConfig(t *testing.T) { + t.Run("fine", func(t *testing.T) { + api := &plugintest.API{} + defaultServerLocale := "en" + api.On("GetConfig").Return(&model.Config{ + LocalizationSettings: model.LocalizationSettings{ + DefaultServerLocale: &defaultServerLocale, + }, + }) + api.On("GetBundlePath").Return(".", nil) + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetServerLocalizer() + lc := &i18n.LocalizeConfig{ + DefaultMessage: &i18n.Message{ + Other: "test messsage", + }, + } + + assert.Equal(t, lc.DefaultMessage.Other, b.LocalizeWithConfig(l, lc)) + }) + + t.Run("empty config", func(t *testing.T) { + api := &plugintest.API{} + defaultServerLocale := "en" + api.On("GetConfig").Return(&model.Config{ + LocalizationSettings: model.LocalizationSettings{ + DefaultServerLocale: &defaultServerLocale, + }, + }) + api.On("GetBundlePath").Return(".", nil) + api.On("LogWarn", mock.AnythingOfType("string"), + mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return() + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetServerLocalizer() + lc := &i18n.LocalizeConfig{} + + assert.Equal(t, "", b.LocalizeWithConfig(l, lc)) + }) + + t.Run("empty message", func(t *testing.T) { + api := &plugintest.API{} + defaultServerLocale := "en" + api.On("GetConfig").Return(&model.Config{ + LocalizationSettings: model.LocalizationSettings{ + DefaultServerLocale: &defaultServerLocale, + }, + }) + api.On("GetBundlePath").Return(".", nil) + api.On("LogWarn", mock.AnythingOfType("string"), + mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return() + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetServerLocalizer() + lc := &i18n.LocalizeConfig{ + DefaultMessage: &i18n.Message{}, + } + + assert.Equal(t, "", b.LocalizeWithConfig(l, lc)) + }) +} +func TestGetUserLocalizer(t *testing.T) { + t.Run("fine", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetUser", "userID").Return(&model.User{ + Locale: "de", + }, nil) + api.On("GetBundlePath").Return(".", nil) + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetUserLocalizer("userID") + assert.NotNil(t, l) + + enMessage := &i18n.Message{ + Other: "a", + } + + deMessage := &i18n.Message{ + Other: "b", + } + + err = b.Bundle.AddMessages(language.German, deMessage) + require.NoError(t, err) + + assert.Equal(t, deMessage.Other, b.LocalizeDefaultMessage(l, enMessage)) + }) + + t.Run("error", func(t *testing.T) { + api := &plugintest.API{} + defaultServerLocale := "es" + api.On("GetConfig").Return(&model.Config{ + LocalizationSettings: model.LocalizationSettings{ + DefaultServerLocale: &defaultServerLocale, + }, + }) + api.On("GetBundlePath").Return(".", nil) + api.On("GetUser", "userID").Return(nil, &model.AppError{}) + api.On("LogWarn", mock.AnythingOfType("string"), + mock.AnythingOfType("string"), mock.AnythingOfType("string"), + mock.AnythingOfType("string"), mock.AnythingOfType("string")).Return() + defer api.AssertExpectations(t) + + b, err := i18n.InitBundle(api, ".") + require.NoError(t, err) + + l := b.GetUserLocalizer("userID") + assert.NotNil(t, l) + + enMessage := &i18n.Message{ + Other: "a", + } + + esMessage := &i18n.Message{ + Other: "b", + } + + err = b.Bundle.AddMessages(language.Spanish, esMessage) + require.NoError(t, err) + + assert.Equal(t, esMessage.Other, b.LocalizeDefaultMessage(l, enMessage)) + }) +} diff --git a/server/public/pluginapi/kv.go b/server/public/pluginapi/kv.go new file mode 100644 index 0000000000..ce7f9a5371 --- /dev/null +++ b/server/public/pluginapi/kv.go @@ -0,0 +1,319 @@ +package pluginapi + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// numRetries is the number of times the setAtomicWithRetries will retry before returning an error. +const numRetries = 5 + +// KVService exposes methods to read and write key-value pairs for the active plugin. +// +// This service cannot be used to read or write key-value pairs for other plugins. +type KVService struct { + api plugin.API +} + +// TODO: Should this be un exported? +type KVSetOptions struct { + model.PluginKVSetOptions + oldValue interface{} +} + +// KVSetOption is an option passed to Set() operation. +type KVSetOption func(*KVSetOptions) + +// SetAtomic guarantees the write will occur only when the current value of matches the given old +// value. A client is expected to read the old value first, then pass it back to ensure the value +// has not since been modified. +func SetAtomic(oldValue interface{}) KVSetOption { + return func(o *KVSetOptions) { + o.Atomic = true + o.oldValue = oldValue + } +} + +// SetExpiry configures a key value to expire after the given duration relative to now. +func SetExpiry(ttl time.Duration) KVSetOption { + return func(o *KVSetOptions) { + o.ExpireInSeconds = int64(ttl / time.Second) + } +} + +// Set stores a key-value pair, unique per plugin. +// Keys prefixed with `mmi_` are reserved for use by this package and will fail to be set. +// +// Returns (false, err) if DB error occurred +// Returns (false, nil) if the value was not set +// Returns (true, nil) if the value was set +// +// Minimum server version: 5.18 +func (k *KVService) Set(key string, value interface{}, options ...KVSetOption) (bool, error) { + if strings.HasPrefix(key, "mmi_") { + return false, errors.New("'mmi_' prefix is not allowed for keys") + } + + opts := KVSetOptions{} + for _, o := range options { + o(&opts) + } + + var valueBytes []byte + if value != nil { + // Assume JSON encoding, unless explicitly given a byte slice. + var isValueInBytes bool + valueBytes, isValueInBytes = value.([]byte) + if !isValueInBytes { + var err error + valueBytes, err = json.Marshal(value) + if err != nil { + return false, errors.Wrapf(err, "failed to marshal value %v", value) + } + } + } + + downstreamOpts := model.PluginKVSetOptions{ + Atomic: opts.Atomic, + ExpireInSeconds: opts.ExpireInSeconds, + } + + if opts.oldValue != nil { + oldValueBytes, isOldValueInBytes := opts.oldValue.([]byte) + if isOldValueInBytes { + downstreamOpts.OldValue = oldValueBytes + } else { + data, err := json.Marshal(opts.oldValue) + if err != nil { + return false, errors.Wrapf(err, "failed to marshal value %v", opts.oldValue) + } + + downstreamOpts.OldValue = data + } + } + + written, appErr := k.api.KVSetWithOptions(key, valueBytes, downstreamOpts) + return written, normalizeAppErr(appErr) +} + +// SetWithExpiry sets a key-value pair with the given expiration duration relative to now. +// +// Deprecated: SetWithExpiry exists to streamline adoption of this package for existing plugins. +// Use Set with the appropriate options instead. +// +// Minimum server version: 5.18 +func (k *KVService) SetWithExpiry(key string, value interface{}, ttl time.Duration) error { + _, err := k.Set(key, value, SetExpiry(ttl)) + + return err +} + +// CompareAndSet writes a key-value pair if the current value matches the given old value. +// +// Returns (false, err) if DB error occurred +// Returns (false, nil) if the value was not set +// Returns (true, nil) if the value was set +// +// Deprecated: CompareAndSet exists to streamline adoption of this package for existing plugins. +// Use Set with the appropriate options instead. +// +// Minimum server version: 5.18 +func (k *KVService) CompareAndSet(key string, oldValue, value interface{}) (bool, error) { + return k.Set(key, value, SetAtomic(oldValue)) +} + +// CompareAndDelete deletes a key-value pair if the current value matches the given old value. +// +// Returns (false, err) if DB error occurred +// Returns (false, nil) if current value != oldValue or key does not exist when deleting +// Returns (true, nil) if current value == oldValue and the key was deleted +// +// Deprecated: CompareAndDelete exists to streamline adoption of this package for existing plugins. +// Use Set with the appropriate options instead. +// +// Minimum server version: 5.18 +func (k *KVService) CompareAndDelete(key string, oldValue interface{}) (bool, error) { + return k.Set(key, nil, SetAtomic(oldValue)) +} + +// SetAtomicWithRetries will set a key-value pair atomically using compare and set semantics: +// it will read key's value (to get oldValue), perform valueFunc (to get newValue), +// and compare and set (comparing oldValue and setting newValue). +// +// Parameters: +// +// `key` is the key to get and set. +// `valueFunc` is a user-provided function that will take the old value as a []byte and +// return the new value or an error. If valueFunc needs to operate on +// oldValue, it will need to use the oldValue as a []byte, or convert +// oldValue into the expected type (e.g., by parsing it, or marshaling it +// into the expected struct). It should then return the newValue as the type +// expected to be stored. +// +// Returns: +// +// Returns err if the key could not be retrieved (DB error), valueFunc returned an error, +// if the key could not be set (DB error), or if the key could not be set (after retries). +// Returns nil if the value was set. +// +// Minimum server version: 5.18 +func (k *KVService) SetAtomicWithRetries(key string, valueFunc func(oldValue []byte) (newValue interface{}, err error)) error { + for i := 0; i < numRetries; i++ { + var oldVal []byte + if err := k.Get(key, &oldVal); err != nil { + return errors.Wrapf(err, "failed to get value for key %s", key) + } + + newVal, err := valueFunc(oldVal) + if err != nil { + return errors.Wrap(err, "valueFunc failed") + } + + if saved, err := k.Set(key, newVal, SetAtomic(oldVal)); err != nil { + return errors.Wrapf(err, "DB failed to set value for key %s", key) + } else if saved { + return nil + } + + // small delay to allow cooperative scheduling to do its thing + time.Sleep(10 * time.Millisecond) + } + return fmt.Errorf("failed to set value after %d retries", numRetries) +} + +// Get gets the value for the given key into the given interface. +// +// An error is returned only if the value cannot be fetched. A non-existent key will return no +// error, with nothing written to the given interface. +// +// Minimum server version: 5.2 +func (k *KVService) Get(key string, o interface{}) error { + data, appErr := k.api.KVGet(key) + if appErr != nil { + return normalizeAppErr(appErr) + } + + if len(data) == 0 { + return nil + } + + if bytesOut, ok := o.(*[]byte); ok { + *bytesOut = data + return nil + } + + if err := json.Unmarshal(data, o); err != nil { + return errors.Wrapf(err, "failed to unmarshal value for key %s", key) + } + + return nil +} + +// Delete deletes the given key-value pair. +// +// An error is returned only if the value failed to be deleted. A non-existent key will return +// no error. +// +// Minimum server version: 5.18 +func (k *KVService) Delete(key string) error { + _, err := k.Set(key, nil) + return err +} + +// DeleteAll removes all key-value pairs. +// +// Minimum server version: 5.6 +func (k *KVService) DeleteAll() error { + return normalizeAppErr(k.api.KVDeleteAll()) +} + +// ListKeysOption used to configure a ListKeys() operation. +type ListKeysOption func(*listKeysOptions) + +// listKeysOptions holds configurations of a ListKeys() operation. +type listKeysOptions struct { + checkers []func(key string) (keep bool, err error) +} + +func (o *listKeysOptions) checkAll(key string) (keep bool, err error) { + for _, check := range o.checkers { + keep, err := check(key) + if err != nil { + return false, err + } + if !keep { + return false, nil + } + } + + // key made it through all checkers + return true, nil +} + +// WithPrefix only return keys that start with the given string. +func WithPrefix(prefix string) ListKeysOption { + return WithChecker(func(key string) (keep bool, err error) { + return strings.HasPrefix(key, prefix), nil + }) +} + +// WithChecker allows for a custom filter function to determine which keys to return. +// Returning true will keep the key and false will filter it out. Returning an error +// will halt KVListWithOptions immediately and pass the error up (with no other results). +func WithChecker(f func(key string) (keep bool, err error)) ListKeysOption { + return func(args *listKeysOptions) { + args.checkers = append(args.checkers, f) + } +} + +// ListKeys lists all keys that match the given options. If no options are provided then all keys are returned. +// +// Minimum server version: 5.6 +func (k *KVService) ListKeys(page, count int, options ...ListKeysOption) ([]string, error) { + // convert functional options into args struct + args := &listKeysOptions{ + checkers: nil, + } + for _, opt := range options { + opt(args) + } + + // get our keys a batch at a time, filter out the ones we don't want based on our args + // any errors will hault the whole process and return the error raw + + keys, appErr := k.api.KVList(page, count) + if appErr != nil { + return nil, normalizeAppErr(appErr) + } + + if len(args.checkers) == 0 { + // no checkers, just return the keys + return keys, nil + } + + ret := make([]string, 0) + // we have a filter, so check each key, all checkers must say key + // for us to keep a key + for _, key := range keys { + keep, err := args.checkAll(key) + if err != nil { + return nil, err + } + + if !keep { + continue + } + + // didn't get filtered out, add to our return + ret = append(ret, key) + } + + return ret, nil +} diff --git a/server/public/pluginapi/kv_test.go b/server/public/pluginapi/kv_test.go new file mode 100644 index 0000000000..f0e43d9e79 --- /dev/null +++ b/server/public/pluginapi/kv_test.go @@ -0,0 +1,681 @@ +package pluginapi_test + +import ( + "encoding/json" + "errors" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func newAppError() *model.AppError { + return model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) +} + +func TestKVSet(t *testing.T) { + tests := []struct { + name string + key string + value interface{} + options []pluginapi.KVSetOption + expectedValue []byte + expectedOptions model.PluginKVSetOptions + upserted bool + err error + }{ + { + "[]byte value", + "1", + 2, + []pluginapi.KVSetOption{}, + []byte(`2`), + model.PluginKVSetOptions{}, + true, + nil, + }, { + "string value", + "1", + "2", + []pluginapi.KVSetOption{}, + []byte(`"2"`), + model.PluginKVSetOptions{}, + true, + nil, + }, { + "struct value", + "1", + struct{ A string }{"2"}, + []pluginapi.KVSetOption{}, + []byte(`{"A":"2"}`), + model.PluginKVSetOptions{}, + true, + nil, + }, { + "compare and set []byte value", + "1", + []byte{2}, + []pluginapi.KVSetOption{ + pluginapi.SetAtomic([]byte{3}), + }, + []byte{2}, + model.PluginKVSetOptions{ + Atomic: true, + OldValue: []byte{3}, + }, + true, + nil, + }, { + "compare and set string value", + "1", + "2", + []pluginapi.KVSetOption{ + pluginapi.SetAtomic("3"), + }, + []byte(`"2"`), + model.PluginKVSetOptions{ + Atomic: true, + OldValue: []byte(`"3"`), + }, true, + nil, + }, { + "value is nil", + "1", + nil, + []pluginapi.KVSetOption{}, + nil, + model.PluginKVSetOptions{}, + true, + nil, + }, { + "current value is nil", + "1", + "2", + []pluginapi.KVSetOption{ + pluginapi.SetAtomic(nil), + }, + []byte(`"2"`), + model.PluginKVSetOptions{ + Atomic: true, + OldValue: nil, + }, + true, + nil, + }, { + "value is nil, current value is []byte", + "1", + nil, + []pluginapi.KVSetOption{ + pluginapi.SetAtomic([]byte{3}), + }, + nil, + model.PluginKVSetOptions{ + Atomic: true, + OldValue: []byte{3}, + }, + true, + nil, + }, { + "error", + "1", + []byte{2}, + []pluginapi.KVSetOption{}, + []byte{2}, + model.PluginKVSetOptions{}, + false, + newAppError(), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + api := &plugintest.API{} + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVSetWithOptions", test.key, test.expectedValue, test.expectedOptions).Return(test.upserted, test.err) + + upserted, err := client.KV.Set(test.key, test.value, test.options...) + if test.err != nil { + require.Error(t, err, test.name) + require.False(t, upserted, test.name) + } else { + require.NoError(t, err, test.name) + assert.True(t, upserted, test.name) + } + api.AssertExpectations(t) + }) + } +} + +func TestSetWithExpiry(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVSetWithOptions", "1", []byte(`2`), model.PluginKVSetOptions{ + ExpireInSeconds: 60, + }).Return(true, nil) + + err := client.KV.SetWithExpiry("1", 2, time.Minute) + require.NoError(t, err) +} + +func TestCompareAndSet(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVSetWithOptions", "1", []byte("2"), model.PluginKVSetOptions{ + Atomic: true, + OldValue: []byte("3"), + }).Return(true, nil) + + upserted, err := client.KV.CompareAndSet("1", 3, 2) + require.NoError(t, err) + assert.True(t, upserted) +} + +func TestCompareAndDelete(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVSetWithOptions", "1", []byte(nil), model.PluginKVSetOptions{ + Atomic: true, + OldValue: []byte("2"), + }).Return(true, nil) + + deleted, err := client.KV.CompareAndDelete("1", 2) + require.NoError(t, err) + assert.True(t, deleted) +} + +func TestSetAtomicWithRetries(t *testing.T) { + tests := []struct { + name string + key string + valueFunc func(t *testing.T) func(old []byte) (interface{}, error) + setupAPI func(api *plugintest.API) + wantErr bool + expectedErrPrefix string + }{ + { + name: "Test SetAtomicWithRetries success after first attempt", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return 2, nil + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + newJSONBytes, _ := json.Marshal(2) + api.On("KVGet", "testNum").Return(oldJSONBytes, nil) + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(true, nil) + }, + }, + { + name: "Test success after first attempt, old is struct and as expected", + key: "testNum2", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + type toStore struct { + Value int + } + var fromDB toStore + if err := json.Unmarshal(old, &fromDB); err != nil { + return nil, err + } + require.Equal(t, 1, fromDB.Value, "old not as expected") + return toStore{2}, nil + } + }, + setupAPI: func(api *plugintest.API) { + type toStore struct { + Value int + } + oldJSONBytes, _ := json.Marshal(toStore{1}) + newJSONBytes, _ := json.Marshal(toStore{2}) + api.On("KVGet", "testNum2").Return(oldJSONBytes, nil) + api.On("KVSetWithOptions", "testNum2", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(true, nil) + }, + }, + { + name: "Test success after first attempt, old is an int value and as expected", + key: "testNum2", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + fromDB, err := strconv.Atoi(string(old)) + if err != nil { + return nil, err + } + require.Equal(t, 1, fromDB, "old not as expected") + return 2, nil + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + newJSONBytes, _ := json.Marshal(2) + api.On("KVGet", "testNum2").Return(oldJSONBytes, nil) + api.On("KVSetWithOptions", "testNum2", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(true, nil) + }, + }, + { + name: "Test SetAtomicWithRetries success on fourth attempt", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return 2, nil + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + newJSONBytes, _ := json.Marshal(2) + api.On("KVGet", "testNum").Return(oldJSONBytes, nil).Times(4) + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(false, nil).Times(3) + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(true, nil).Once() + }, + }, + { + name: "Test SetAtomicWithRetries success on fourth attempt because value was changed between calls to KVGet", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return 2, nil + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + newJSONBytes, _ := json.Marshal(2) + api.On("KVGet", "testNum").Return(oldJSONBytes, nil).Times(4) + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(false, nil).Times(3) + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(true, nil).Once() + }, + }, + { + name: "Test SetAtomicWithRetries failure on get", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return nil, errors.New("should not have got here") + } + }, + setupAPI: func(api *plugintest.API) { + api.On("KVGet", "testNum").Return(nil, newAppError()).Once() + }, + wantErr: true, + expectedErrPrefix: "failed to get value for key testNum", + }, + { + name: "Test SetAtomicWithRetries failure on valueFunc", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return nil, errors.New("some user provided error") + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + api.On("KVGet", "testNum").Return(oldJSONBytes, nil).Once() + }, + wantErr: true, + expectedErrPrefix: "valueFunc failed: some user provided error", + }, + { + name: "Test SetAtomicWithRetries DB failure on set", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return 2, nil + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + newJSONBytes, _ := json.Marshal(2) + api.On("KVGet", "testNum").Return(oldJSONBytes, nil).Once() + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(false, newAppError()).Once() + }, + wantErr: true, + expectedErrPrefix: "DB failed to set value for key testNum", + }, + { + name: "Test SetAtomicWithRetries failure on five set attempts -- depends on numRetries constant being = 5", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + return 2, nil + } + }, + setupAPI: func(api *plugintest.API) { + oldJSONBytes, _ := json.Marshal(1) + newJSONBytes, _ := json.Marshal(2) + api.On("KVGet", "testNum").Return(oldJSONBytes, nil).Times(5) + api.On("KVSetWithOptions", "testNum", newJSONBytes, model.PluginKVSetOptions{ + Atomic: true, + OldValue: oldJSONBytes, + }).Return(false, nil).Times(5) + }, + wantErr: true, + expectedErrPrefix: "failed to set value after 5 retries", + }, + { + name: "Test SetAtomicWithRetries success after five set attempts -- depends on numRetries constant being = 5", + key: "testNum", + valueFunc: func(t *testing.T) func(old []byte) (interface{}, error) { + return func(old []byte) (interface{}, error) { + fromDB, err := strconv.Atoi(string(old)) + if err != nil { + return nil, err + } + return fromDB + 1, nil + } + }, + setupAPI: func(api *plugintest.API) { + i1, _ := json.Marshal(1) + i2, _ := json.Marshal(2) + i3, _ := json.Marshal(3) + i4, _ := json.Marshal(4) + i5, _ := json.Marshal(5) + i6, _ := json.Marshal(6) + api.On("KVGet", "testNum").Return(i1, nil).Once() + api.On("KVSetWithOptions", "testNum", i2, model.PluginKVSetOptions{ + Atomic: true, + OldValue: i1, + }).Return(false, nil).Once() + api.On("KVGet", "testNum").Return(i2, nil).Once() + api.On("KVSetWithOptions", "testNum", i3, model.PluginKVSetOptions{ + Atomic: true, + OldValue: i2, + }).Return(false, nil).Once() + api.On("KVGet", "testNum").Return(i3, nil).Once() + api.On("KVSetWithOptions", "testNum", i4, model.PluginKVSetOptions{ + Atomic: true, + OldValue: i3, + }).Return(false, nil).Once() + api.On("KVGet", "testNum").Return(i4, nil).Once() + api.On("KVSetWithOptions", "testNum", i5, model.PluginKVSetOptions{ + Atomic: true, + OldValue: i4, + }).Return(false, nil).Once() + api.On("KVGet", "testNum").Return(i5, nil).Once() + api.On("KVSetWithOptions", "testNum", i6, model.PluginKVSetOptions{ + Atomic: true, + OldValue: i5, + }).Return(true, nil).Once() + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + tt.setupAPI(api) + + err := client.KV.SetAtomicWithRetries(tt.key, tt.valueFunc(t)) + if tt.wantErr { + if err == nil { + t.Errorf("SetAtomicWithRetries() error = %v, wantErr %v", err, tt.wantErr) + } + if !strings.HasPrefix(err.Error(), tt.expectedErrPrefix) { + t.Errorf("SetAtomicWithRetries() error = %s, expected prefix = %s", err, tt.expectedErrPrefix) + } + } + }) + } +} + +func TestGet(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + aStringJSON, _ := json.Marshal("2") + + api.On("KVGet", "1").Return(aStringJSON, nil) + + var out string + err := client.KV.Get("1", &out) + require.NoError(t, err) + assert.Equal(t, "2", out) +} + +func TestGetNilKey(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVGet", "1").Return(nil, nil) + + var out string + err := client.KV.Get("1", &out) + require.NoError(t, err) + assert.Empty(t, out) +} + +func TestGetInBytes(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVGet", "1").Return([]byte{2}, nil) + + var out []byte + err := client.KV.Get("1", &out) + require.NoError(t, err) + assert.Equal(t, []byte{2}, out) + api.AssertExpectations(t) +} + +func TestDelete(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVSetWithOptions", "1", []byte(nil), model.PluginKVSetOptions{}).Return(true, nil) + + err := client.KV.Delete("1") + require.NoError(t, err) +} + +func TestDeleteAll(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVDeleteAll").Return(nil) + + err := client.KV.DeleteAll() + require.NoError(t, err) +} + +func TestListKeys(t *testing.T) { + t.Run("No keys", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(nil, nil) + + keys, err := client.KV.ListKeys(0, 100) + + assert.Empty(t, keys) + assert.NoError(t, err) + }) + + t.Run("Basic Success, one page", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 1, 2).Return(getKeys(2), nil) + + keys, err := client.KV.ListKeys(1, 2) + require.NoError(t, err) + require.Equal(t, getKeys(2), keys) + }) + + t.Run("success, two page, filter prefix, one", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(getKeys(100), nil) + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithPrefix("key99")) + assert.ElementsMatch(t, keys, []string{"key99"}) + assert.NoError(t, err) + }) + + t.Run("success, two page, filter prefix, all", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(getKeys(100), nil) + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithPrefix("notkey")) + assert.Empty(t, keys) + assert.NoError(t, err) + }) + + t.Run("success, two page, filter prefix, none", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(getKeys(100), nil) + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithPrefix("key")) + assert.ElementsMatch(t, keys, getKeys(100)) + assert.NoError(t, err) + }) + + t.Run("success, two page, checker func, one", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(getKeys(100), nil) + + check := func(key string) (bool, error) { + if key == "key1" { + return true, nil + } + return false, nil + } + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithChecker(check)) + assert.ElementsMatch(t, keys, []string{"key1"}) + assert.NoError(t, err) + }) + + t.Run("success, two page, checker func, all", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(getKeys(100), nil) + + check := func(key string) (bool, error) { + return false, nil + } + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithChecker(check)) + assert.Empty(t, keys) + assert.NoError(t, err) + }) + + t.Run("success, two page, checker func, none", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return(getKeys(100), nil) + + check := func(key string) (bool, error) { + return true, nil + } + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithChecker(check)) + assert.ElementsMatch(t, keys, getKeys(100)) + assert.NoError(t, err) + }) + + t.Run("error, checker func", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return([]string{"key1"}, nil) + + check := func(key string) (bool, error) { + return true, &model.AppError{} + } + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithChecker(check)) + assert.Empty(t, keys) + assert.Error(t, err) + }) + + t.Run("success, filter and checker func, partial on both", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("KVList", 0, 100).Return([]string{"key1", "key2", "notkey3", "key4", "key5"}, nil) + + check := func(key string) (bool, error) { + if key == "key1" || key == "key5" { + return false, nil + } + return true, nil + } + + keys, err := client.KV.ListKeys(0, 100, pluginapi.WithPrefix("key"), pluginapi.WithChecker(check)) + assert.ElementsMatch(t, keys, []string{"key2", "key4"}) + assert.NoError(t, err) + }) +} + +func getKeys(count int) []string { + ret := make([]string, count) + for i := 0; i < count; i++ { + ret[i] = "key" + strconv.Itoa(i) + } + return ret +} diff --git a/server/public/pluginapi/license.go b/server/public/pluginapi/license.go new file mode 100644 index 0000000000..f332280cef --- /dev/null +++ b/server/public/pluginapi/license.go @@ -0,0 +1,107 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" +) + +const ( + e10 = "E10" + e20 = "E20" + professional = "professional" + enterprise = "enterprise" +) + +// IsEnterpriseLicensedOrDevelopment returns true when the server is licensed with any Mattermost +// Enterprise License, or has `EnableDeveloper` and `EnableTesting` configuration settings +// enabled signaling a non-production, developer mode. +func IsEnterpriseLicensedOrDevelopment(config *model.Config, license *model.License) bool { + if license != nil { + return true + } + + return IsConfiguredForDevelopment(config) +} + +// isValidSkuShortName returns whether the SKU short name is one of the known strings; +// namely: E10 or professional, or E20 or enterprise +func isValidSkuShortName(license *model.License) bool { + if license == nil { + return false + } + + switch license.SkuShortName { + case e10, e20, professional, enterprise: + return true + default: + return false + } +} + +// IsE10LicensedOrDevelopment returns true when the server is at least licensed with a legacy Mattermost +// Enterprise E10 License or a Mattermost Professional License, or has `EnableDeveloper` and +// `EnableTesting` configuration settings enabled, signaling a non-production, developer mode. +func IsE10LicensedOrDevelopment(config *model.Config, license *model.License) bool { + if license != nil && + (license.SkuShortName == e10 || license.SkuShortName == professional || + license.SkuShortName == e20 || license.SkuShortName == enterprise) { + return true + } + + if !isValidSkuShortName(license) { + // As a fallback for licenses whose SKU short name is unknown, make a best effort to try + // and use the presence of a known E10/Professional feature as a check to determine licensing. + if license != nil && + license.Features != nil && + license.Features.LDAP != nil && + *license.Features.LDAP { + return true + } + } + + return IsConfiguredForDevelopment(config) +} + +// IsE20LicensedOrDevelopment returns true when the server is licensed with a legacy Mattermost +// Enterprise E20 License or a Mattermost Enterprise License, or has `EnableDeveloper` and +// `EnableTesting` configuration settings enabled, signaling a non-production, developer mode. +func IsE20LicensedOrDevelopment(config *model.Config, license *model.License) bool { + if license != nil && (license.SkuShortName == e20 || license.SkuShortName == enterprise) { + return true + } + + if !isValidSkuShortName(license) { + // As a fallback for licenses whose SKU short name is unknown, make a best effort to try + // and use the presence of a known E20/Enterprise feature as a check to determine licensing. + if license != nil && + license.Features != nil && + license.Features.FutureFeatures != nil && + *license.Features.FutureFeatures { + return true + } + } + + return IsConfiguredForDevelopment(config) +} + +// IsConfiguredForDevelopment returns true when the server has `EnableDeveloper` and `EnableTesting` +// configuration settings enabled, signaling a non-production, developer mode. +func IsConfiguredForDevelopment(config *model.Config) bool { + if config != nil && + config.ServiceSettings.EnableTesting != nil && + *config.ServiceSettings.EnableTesting && + config.ServiceSettings.EnableDeveloper != nil && + *config.ServiceSettings.EnableDeveloper { + return true + } + + return false +} + +// IsCloud returns true when the server is on cloud, and false otherwise. +func IsCloud(license *model.License) bool { + if license == nil || license.Features == nil || license.Features.Cloud == nil { + return false + } + + return *license.Features.Cloud +} diff --git a/server/public/pluginapi/license_test.go b/server/public/pluginapi/license_test.go new file mode 100644 index 0000000000..55835e9711 --- /dev/null +++ b/server/public/pluginapi/license_test.go @@ -0,0 +1,330 @@ +package pluginapi + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mattermost/mattermost/server/public/model" +) + +func TestIsEnterpriseLicensedOrDevelopment(t *testing.T) { + t.Run("license, no config", func(t *testing.T) { + assert.True(t, IsEnterpriseLicensedOrDevelopment(nil, &model.License{})) + }) + + t.Run("license, nil config", func(t *testing.T) { + assert.True(t, IsEnterpriseLicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: nil, EnableTesting: nil}}, + &model.License{}, + )) + }) + + t.Run("no license, no config", func(t *testing.T) { + assert.False(t, IsEnterpriseLicensedOrDevelopment(nil, nil)) + }) + + t.Run("no license, nil config", func(t *testing.T) { + assert.False(t, IsEnterpriseLicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: nil, EnableTesting: nil}}, + nil, + )) + }) + + t.Run("no license, only developer mode", func(t *testing.T) { + assert.False(t, IsEnterpriseLicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(true), EnableTesting: bToP(false)}}, + nil, + )) + }) + + t.Run("no license, only testing mode", func(t *testing.T) { + assert.False(t, IsEnterpriseLicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(false), EnableTesting: bToP(true)}}, + nil, + )) + }) + + t.Run("no license, developer and testing mode", func(t *testing.T) { + assert.True(t, IsEnterpriseLicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(true), EnableTesting: bToP(true)}}, + nil, + )) + }) +} + +func TestIsE20LicensedOrDevelopment(t *testing.T) { + t.Run("nil license features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{})) + }) + + t.Run("nil future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{Features: &model.Features{}})) + }) + + t.Run("disabled future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{Features: &model.Features{ + FutureFeatures: bToP(false), + }})) + }) + + t.Run("enabled future features", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment(nil, &model.License{Features: &model.Features{ + FutureFeatures: bToP(true), + }})) + }) + + t.Run("no license, no config", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, nil)) + }) + + t.Run("no license, nil config", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: nil, EnableTesting: nil}}, + nil, + )) + }) + + t.Run("no license, only developer mode", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(true), EnableTesting: bToP(false)}}, + nil, + )) + }) + + t.Run("no license, only testing mode", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(false), EnableTesting: bToP(true)}}, + nil, + )) + }) + + t.Run("no license, developer and testing mode", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(true), EnableTesting: bToP(true)}}, + nil, + )) + }) + + t.Run("license with E10 SKU name, disabled future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E10", + Features: &model.Features{FutureFeatures: bToP(false)}, + })) + }) + + t.Run("license with E10 SKU name, enabled future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E10", + Features: &model.Features{FutureFeatures: bToP(true)}, + })) + }) + + t.Run("license with professional SKU name, disabled future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "professional", + Features: &model.Features{FutureFeatures: bToP(false)}, + })) + }) + + t.Run("license with professional SKU name, enabled future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "professional", + Features: &model.Features{FutureFeatures: bToP(true)}, + })) + }) + t.Run("license with E20 SKU name, disabled future features", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E20", + Features: &model.Features{FutureFeatures: bToP(false)}, + })) + }) + + t.Run("license with E20 SKU name, enabled future features", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E20", + Features: &model.Features{FutureFeatures: bToP(true)}, + })) + }) + + t.Run("license with enterprise SKU name, disabled future features", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "enterprise", + Features: &model.Features{FutureFeatures: bToP(false)}, + })) + }) + + t.Run("license with enterprise SKU name, enabled future features", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "enterprise", + Features: &model.Features{FutureFeatures: bToP(true)}, + })) + }) + + t.Run("license with unknown SKU name, disabled future features", func(t *testing.T) { + assert.False(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "unknown", + Features: &model.Features{FutureFeatures: bToP(false)}, + })) + }) + + t.Run("license with unknown SKU name, enabled future features", func(t *testing.T) { + assert.True(t, IsE20LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "unknown", + Features: &model.Features{FutureFeatures: bToP(true)}, + })) + }) +} + +func TestIsE10LicensedOrDevelopment(t *testing.T) { + t.Run("nil license features", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment(nil, &model.License{})) + }) + + t.Run("nil future features", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment(nil, &model.License{Features: &model.Features{}})) + }) + + t.Run("disabled LDAP", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment(nil, &model.License{Features: &model.Features{ + LDAP: bToP(false), + }})) + }) + + t.Run("enabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{Features: &model.Features{ + LDAP: bToP(true), + }})) + }) + + t.Run("no license, no config", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment(nil, nil)) + }) + + t.Run("no license, nil config", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: nil, EnableTesting: nil}}, + nil, + )) + }) + + t.Run("no license, only developer mode", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(true), EnableTesting: bToP(false)}}, + nil, + )) + }) + + t.Run("no license, only testing mode", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(false), EnableTesting: bToP(true)}}, + nil, + )) + }) + + t.Run("no license, developer and testing mode", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment( + &model.Config{ServiceSettings: model.ServiceSettings{EnableDeveloper: bToP(true), EnableTesting: bToP(true)}}, + nil, + )) + }) + + t.Run("license with E10 SKU name, disabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E10", + Features: &model.Features{LDAP: bToP(false)}, + })) + }) + + t.Run("license with E10 SKU name, enabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E10", + Features: &model.Features{LDAP: bToP(true)}, + })) + }) + + t.Run("license with professional SKU name, disabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "professional", + Features: &model.Features{LDAP: bToP(false)}, + })) + }) + + t.Run("license with professional SKU name, enabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "professional", + Features: &model.Features{LDAP: bToP(true)}, + })) + }) + t.Run("license with E20 SKU name, disabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E20", + Features: &model.Features{LDAP: bToP(false)}, + })) + }) + + t.Run("license with E20 SKU name, enabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "E20", + Features: &model.Features{LDAP: bToP(true)}, + })) + }) + + t.Run("license with enterprise SKU name, disabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "enterprise", + Features: &model.Features{LDAP: bToP(false)}, + })) + }) + + t.Run("license with enterprise SKU name, enabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "enterprise", + Features: &model.Features{LDAP: bToP(true)}, + })) + }) + + t.Run("license with unknown SKU name, disabled LDAP", func(t *testing.T) { + assert.False(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "unknown", + Features: &model.Features{LDAP: bToP(false)}, + })) + }) + + t.Run("license with unknown SKU name, enabled LDAP", func(t *testing.T) { + assert.True(t, IsE10LicensedOrDevelopment(nil, &model.License{ + SkuShortName: "unknown", + Features: &model.Features{LDAP: bToP(true)}, + })) + }) +} + +func TestIsValidSKUShortName(t *testing.T) { + t.Run("nil license", func(t *testing.T) { + assert.False(t, isValidSkuShortName(nil)) + }) + + t.Run("license with valid E10 SKU name", func(t *testing.T) { + assert.True(t, isValidSkuShortName(&model.License{SkuShortName: "E10"})) + }) + + t.Run("license with valid E20 SKU name", func(t *testing.T) { + assert.True(t, isValidSkuShortName(&model.License{SkuShortName: "E20"})) + }) + + t.Run("license with valid professional SKU name", func(t *testing.T) { + assert.True(t, isValidSkuShortName(&model.License{SkuShortName: "professional"})) + }) + + t.Run("license with valid enterprise SKU name", func(t *testing.T) { + assert.True(t, isValidSkuShortName(&model.License{SkuShortName: "enterprise"})) + }) + + t.Run("license with invalid SKU name", func(t *testing.T) { + assert.False(t, isValidSkuShortName(&model.License{SkuShortName: "invalid"})) + }) +} + +func bToP(b bool) *bool { + return &b +} diff --git a/server/public/pluginapi/log.go b/server/public/pluginapi/log.go new file mode 100644 index 0000000000..34184b6602 --- /dev/null +++ b/server/public/pluginapi/log.go @@ -0,0 +1,33 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/plugin" +) + +// LogService exposes methods to log to the Mattermost server log. +// +// Note that standard error is automatically sent to the Mattermost server log, and standard +// output is redirected to standard error. This service enables optional structured logging. +type LogService struct { + api plugin.API +} + +// Error logs an error message, optionally structured with alternating key, value parameters. +func (l *LogService) Error(message string, keyValuePairs ...interface{}) { + l.api.LogError(message, keyValuePairs...) +} + +// Warn logs an error message, optionally structured with alternating key, value parameters. +func (l *LogService) Warn(message string, keyValuePairs ...interface{}) { + l.api.LogWarn(message, keyValuePairs...) +} + +// Info logs an error message, optionally structured with alternating key, value parameters. +func (l *LogService) Info(message string, keyValuePairs ...interface{}) { + l.api.LogInfo(message, keyValuePairs...) +} + +// Debug logs an error message, optionally structured with alternating key, value parameters. +func (l *LogService) Debug(message string, keyValuePairs ...interface{}) { + l.api.LogDebug(message, keyValuePairs...) +} diff --git a/server/public/pluginapi/logrus.go b/server/public/pluginapi/logrus.go new file mode 100644 index 0000000000..37999d012f --- /dev/null +++ b/server/public/pluginapi/logrus.go @@ -0,0 +1,69 @@ +package pluginapi + +import ( + "fmt" + "io" + + "github.com/sirupsen/logrus" +) + +// LogrusHook is a logrus.Hook for emitting plugin logs through the RPC API for inclusion in the +// server logs. +// +// To configure the default Logrus logger for use with plugin logging, simply invoke: +// +// pluginapi.ConfigureLogrus(logrus.StandardLogger(), pluginAPIClient) +// +// Alternatively, construct your own logger to pass to pluginapi.ConfigureLogrus. +type LogrusHook struct { + log LogService +} + +// NewLogrusHook creates a new instance of LogrusHook. +func NewLogrusHook(log LogService) *LogrusHook { + return &LogrusHook{ + log: log, + } +} + +// Levels allows LogrusHook to process any log level. +func (lh *LogrusHook) Levels() []logrus.Level { + return logrus.AllLevels +} + +// Fire proxies logrus entries through the plugin API at the appropriate level. +func (lh *LogrusHook) Fire(entry *logrus.Entry) error { + fields := []interface{}{} + for key, value := range entry.Data { + fields = append(fields, key, fmt.Sprintf("%+v", value)) + } + + if entry.Caller != nil { + fields = append(fields, "plugin_caller", fmt.Sprintf("%s:%d", entry.Caller.File, entry.Caller.Line)) + } + + switch entry.Level { + case logrus.PanicLevel, logrus.FatalLevel, logrus.ErrorLevel: + lh.log.Error(entry.Message, fields...) + case logrus.WarnLevel: + lh.log.Warn(entry.Message, fields...) + case logrus.InfoLevel: + lh.log.Info(entry.Message, fields...) + case logrus.DebugLevel, logrus.TraceLevel: + lh.log.Debug(entry.Message, fields...) + } + + return nil +} + +// ConfigureLogrus configures the given logrus logger with a hook to proxy through the RPC API, +// discarding the default output to avoid duplicating the events across the standard STDOUT proxy. +func ConfigureLogrus(logger *logrus.Logger, client *Client) { + hook := NewLogrusHook(client.Log) + logger.Hooks.Add(hook) + logger.SetOutput(io.Discard) + logrus.SetReportCaller(true) + + // By default, log everything to the server, and let it decide what gets through. + logrus.SetLevel(logrus.TraceLevel) +} diff --git a/server/public/pluginapi/logrus_test.go b/server/public/pluginapi/logrus_test.go new file mode 100644 index 0000000000..04c5481a58 --- /dev/null +++ b/server/public/pluginapi/logrus_test.go @@ -0,0 +1,85 @@ +package pluginapi_test + +import ( + "testing" + + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestLogrus(t *testing.T) { + testCases := []struct { + Level logrus.Level + APICall string + }{ + {logrus.PanicLevel, "LogError"}, + {logrus.FatalLevel, "LogError"}, + {logrus.ErrorLevel, "LogError"}, + {logrus.WarnLevel, "LogWarn"}, + {logrus.InfoLevel, "LogInfo"}, + {logrus.DebugLevel, "LogDebug"}, + {logrus.TraceLevel, "LogDebug"}, + } + + for _, testCase := range testCases { + t.Run(testCase.Level.String(), func(t *testing.T) { + logger := logrus.New() + logger.SetLevel(logrus.TraceLevel) // not testing logrus filtering + logger.ReportCaller = true + + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + pluginapi.ConfigureLogrus(logger, client) + + // Parameter order of map is non-deterministic, so expect either. + api.On(testCase.APICall, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) + + entry := logger.WithFields(logrus.Fields{ + "a": "a", + "b": 1, + }) + + if testCase.Level == logrus.PanicLevel { + done := make(chan bool) + go func() { + defer func() { + r := recover() + assert.NotNil(t, r, "expected panic") + close(done) + }() + + entry.Panic("message") + }() + <-done + } else { + entry.Log(testCase.Level, "message") + } + + // Assert the required API call was executed at most once. + if api.AssertNumberOfCalls(t, testCase.APICall, 1) { + call := api.Calls[0] + for i := 1; i < len(call.Arguments)-1; i += 2 { + argument := call.Arguments[i] + value := call.Arguments[i+1] + + switch argument { + case "a": + assert.Equal(t, "a", value, "unexpected value for a") + case "b": + assert.Equal(t, "1", value, "unexpected value for b") + case "plugin_caller": + assert.IsType(t, "string", value) + default: + assert.Fail(t, "unexpected argument and value", "%v: %v", argument, value) + } + } + } + }) + } +} diff --git a/server/public/pluginapi/oauth.go b/server/public/pluginapi/oauth.go new file mode 100644 index 0000000000..ef3dd8ee87 --- /dev/null +++ b/server/public/pluginapi/oauth.go @@ -0,0 +1,55 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// UserService exposes methods to manipulate OAuth Apps. +type OAuthService struct { + api plugin.API +} + +// Create creates a new OAuth App. +// +// Minimum server version: 5.38 +func (o *OAuthService) Create(app *model.OAuthApp) error { + createdApp, appErr := o.api.CreateOAuthApp(app) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *app = *createdApp + + return nil +} + +// Get gets an existing OAuth App by id. +// +// Minimum server version: 5.38 +func (o *OAuthService) Get(appID string) (*model.OAuthApp, error) { + app, appErr := o.api.GetOAuthApp(appID) + + return app, normalizeAppErr(appErr) +} + +// Update updates an existing OAuth App. +// +// Minimum server version: 5.38 +func (o *OAuthService) Update(app *model.OAuthApp) error { + updatedApp, appErr := o.api.UpdateOAuthApp(app) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *app = *updatedApp + + return nil +} + +// Delete deletes an existing OAuth App by id. +// +// Minimum server version: 5.38 +func (o *OAuthService) Delete(appID string) error { + return normalizeAppErr(o.api.DeleteOAuthApp(appID)) +} diff --git a/server/public/pluginapi/plugin_test.go b/server/public/pluginapi/plugin_test.go new file mode 100644 index 0000000000..98948cc906 --- /dev/null +++ b/server/public/pluginapi/plugin_test.go @@ -0,0 +1,170 @@ +package pluginapi_test + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/plugin/plugintest/mock" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestInstallPluginFromURL(t *testing.T) { + replace := true + + t.Run("incompatible server version", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetServerVersion").Return("5.1.0") + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + _, err := client.Plugin.InstallPluginFromURL("", true) + + assert.Error(t, err) + assert.Equal(t, "incompatible server version for plugin, minimum required version: 5.18.0, current version: 5.1.0", err.Error()) + }) + + t.Run("error while parsing the download url", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetServerVersion").Return("5.19.0") + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + _, err := client.Plugin.InstallPluginFromURL("http://%41:8080/", replace) + + assert.Error(t, err) + assert.Equal(t, "error while parsing url: parse \"http://%41:8080/\": invalid URL escape \"%41\"", err.Error()) + }) + + t.Run("errors out while downloading file", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetServerVersion").Return("5.19.0") + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + res.WriteHeader(http.StatusInternalServerError) + })) + defer testServer.Close() + url := testServer.URL + + _, err := client.Plugin.InstallPluginFromURL(url, replace) + + assert.Error(t, err) + assert.Equal(t, "received 500 status code while downloading plugin from server", err.Error()) + }) + + t.Run("downloads the file successfully", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetServerVersion").Return("5.19.0") + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + tarData, err := os.ReadFile(filepath.Join("../../../tests", "testplugin.tar.gz")) + require.NoError(t, err) + expectedManifest := &model.Manifest{Id: "testplugin"} + api.On("InstallPlugin", mock.Anything, false).Return(expectedManifest, nil) + + testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + res.WriteHeader(http.StatusOK) + _, _ = res.Write(tarData) + })) + defer testServer.Close() + url := testServer.URL + + manifest, err := client.Plugin.InstallPluginFromURL(url, false) + + assert.NoError(t, err) + assert.Equal(t, "testplugin", manifest.Id) + }) + + t.Run("the url pointing to server is incorrect", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetServerVersion").Return("5.19.0") + client := pluginapi.NewClient(api, &plugintest.Driver{}) + testServer := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + res.WriteHeader(http.StatusNotFound) + })) + defer testServer.Close() + url := testServer.URL + + _, err := client.Plugin.InstallPluginFromURL(url, false) + + assert.Error(t, err) + assert.Equal(t, "received 404 status code while downloading plugin from server", err.Error()) + }) +} + +func TestGetPluginAssetURL(t *testing.T) { + siteURL := "https://mattermost.example.com" + api := &plugintest.API{} + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: &siteURL}}) + + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + t.Run("Valid asset directory was provided", func(t *testing.T) { + pluginID := "mattermost-1234" + dir := "assets" + wantedURL := "https://mattermost.example.com/mattermost-1234/assets" + gotURL, err := client.System.GetPluginAssetURL(pluginID, dir) + + assert.Equalf(t, wantedURL, gotURL, "GetPluginAssetURL(%q, %q) got=%q; want=%v", pluginID, dir, gotURL, wantedURL) + assert.NoError(t, err) + }) + + t.Run("Valid asset directory path was provided", func(t *testing.T) { + pluginID := "mattermost-1234" + dirPath := "/mattermost/assets" + wantedURL := "https://mattermost.example.com/mattermost-1234/mattermost/assets" + gotURL, err := client.System.GetPluginAssetURL(pluginID, dirPath) + + assert.Equalf(t, wantedURL, gotURL, "GetPluginAssetURL(%q, %q) got=%q; want=%q", pluginID, dirPath, gotURL, wantedURL) + assert.NoError(t, err) + }) + + t.Run("Valid pluginID was provided", func(t *testing.T) { + pluginID := "mattermost-1234" + dir := "assets" + wantedURL := "https://mattermost.example.com/mattermost-1234/assets" + gotURL, err := client.System.GetPluginAssetURL(pluginID, dir) + + assert.Equalf(t, wantedURL, gotURL, "GetPluginAssetURL(%q, %q) got=%q; want=%q", pluginID, dir, gotURL, wantedURL) + assert.NoError(t, err) + }) + + t.Run("Invalid asset directory name was provided", func(t *testing.T) { + pluginID := "mattermost-1234" + dir := "" + want := "" + gotURL, err := client.System.GetPluginAssetURL(pluginID, dir) + + assert.Emptyf(t, gotURL, "GetPluginAssetURL(%q, %q) got=%s; want=%q", pluginID, dir, gotURL, want) + assert.Error(t, err) + }) + + t.Run("Invalid pluginID was provided", func(t *testing.T) { + pluginID := "" + dir := "assets" + want := "" + gotURL, err := client.System.GetPluginAssetURL(pluginID, dir) + + assert.Emptyf(t, gotURL, "GetPluginAssetURL(%q, %q) got=%q; want=%q", pluginID, dir, gotURL, want) + assert.Error(t, err) + }) + + siteURL = "" + api.On("GetConfig").Return(&model.Config{ServiceSettings: model.ServiceSettings{SiteURL: &siteURL}}) + + t.Run("Empty SiteURL was configured", func(t *testing.T) { + pluginID := "mattermost-1234" + dir := "assets" + want := "" + gotURL, err := client.System.GetPluginAssetURL(pluginID, dir) + + assert.Emptyf(t, gotURL, "GetPluginAssetURL(%q, %q) got=%q; want=%q", pluginID, dir, gotURL, want) + assert.Error(t, err) + }) +} diff --git a/server/public/pluginapi/plugins.go b/server/public/pluginapi/plugins.go new file mode 100644 index 0000000000..babe7d2b23 --- /dev/null +++ b/server/public/pluginapi/plugins.go @@ -0,0 +1,114 @@ +package pluginapi + +import ( + "io" + "net/http" + "net/url" + "time" + + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// PluginService exposes methods to manipulate the set of plugins as well as communicate with +// other plugin instances. +type PluginService struct { + api plugin.API +} + +// List will return a list of plugin manifests for currently active plugins. +// +// Minimum server version: 5.6 +func (p *PluginService) List() ([]*model.Manifest, error) { + manifests, appErr := p.api.GetPlugins() + + return manifests, normalizeAppErr(appErr) +} + +// Install will upload another plugin with tar.gz file. +// Previous version will be replaced on replace true. +// +// Minimum server version: 5.18 +func (p *PluginService) Install(file io.Reader, replace bool) (*model.Manifest, error) { + manifest, appErr := p.api.InstallPlugin(file, replace) + + return manifest, normalizeAppErr(appErr) +} + +// InstallPluginFromURL installs the plugin from the provided url. +// +// Minimum server version: 5.18 +func (p *PluginService) InstallPluginFromURL(downloadURL string, replace bool) (*model.Manifest, error) { + err := ensureServerVersion(p.api, "5.18.0") + if err != nil { + return nil, err + } + + parsedURL, err := url.Parse(downloadURL) + if err != nil { + return nil, errors.Wrap(err, "error while parsing url") + } + + client := &http.Client{Timeout: time.Hour} + response, err := client.Get(parsedURL.String()) + if err != nil { + return nil, errors.Wrap(err, "unable to download the plugin") + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, errors.Errorf("received %d status code while downloading plugin from server", response.StatusCode) + } + + manifest, err := p.Install(response.Body, replace) + if err != nil { + return nil, errors.Wrap(err, "unable to install plugin on server") + } + + return manifest, nil +} + +// Enable will enable an plugin installed. +// +// Minimum server version: 5.6 +func (p *PluginService) Enable(id string) error { + appErr := p.api.EnablePlugin(id) + + return normalizeAppErr(appErr) +} + +// Disable will disable an enabled plugin. +// +// Minimum server version: 5.6 +func (p *PluginService) Disable(id string) error { + appErr := p.api.DisablePlugin(id) + + return normalizeAppErr(appErr) +} + +// Remove will disable and delete a plugin. +// +// Minimum server version: 5.6 +func (p *PluginService) Remove(id string) error { + appErr := p.api.RemovePlugin(id) + + return normalizeAppErr(appErr) +} + +// GetPluginStatus will return the status of a plugin. +// +// Minimum server version: 5.6 +func (p *PluginService) GetPluginStatus(id string) (*model.PluginStatus, error) { + pluginStatus, appErr := p.api.GetPluginStatus(id) + + return pluginStatus, normalizeAppErr(appErr) +} + +// HTTP allows inter-plugin requests to plugin APIs. +// +// Minimum server version: 5.18 +func (p *PluginService) HTTP(request *http.Request) *http.Response { + return p.api.PluginHTTP(request) +} diff --git a/server/public/pluginapi/post.go b/server/public/pluginapi/post.go new file mode 100644 index 0000000000..e4328fe4dd --- /dev/null +++ b/server/public/pluginapi/post.go @@ -0,0 +1,335 @@ +package pluginapi + +import ( + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// PostService exposes methods to manipulate posts. +type PostService struct { + api plugin.API +} + +// CreatePost creates a post. +// +// Minimum server version: 5.2 +func (p *PostService) CreatePost(post *model.Post) error { + createdPost, appErr := p.api.CreatePost(post) + if appErr != nil { + return normalizeAppErr(appErr) + } + + err := createdPost.ShallowCopy(post) + if err != nil { + return err + } + + return nil +} + +// DM sends a post as a direct message +// +// Minimum server version: 5.2 +func (p *PostService) DM(senderUserID, receiverUserID string, post *model.Post) error { + channel, appErr := p.api.GetDirectChannel(senderUserID, receiverUserID) + if appErr != nil { + return normalizeAppErr(appErr) + } + post.ChannelId = channel.Id + post.UserId = senderUserID + return p.CreatePost(post) +} + +// GetPost gets a post. +// +// Minimum server version: 5.2 +func (p *PostService) GetPost(postID string) (*model.Post, error) { + post, appErr := p.api.GetPost(postID) + + return post, normalizeAppErr(appErr) +} + +// UpdatePost updates a post. +// +// Minimum server version: 5.2 +func (p *PostService) UpdatePost(post *model.Post) error { + updatedPost, appErr := p.api.UpdatePost(post) + if appErr != nil { + return normalizeAppErr(appErr) + } + + err := updatedPost.ShallowCopy(post) + if err != nil { + return err + } + + return nil +} + +// DeletePost deletes a post. +// +// Minimum server version: 5.2 +func (p *PostService) DeletePost(postID string) error { + return normalizeAppErr(p.api.DeletePost(postID)) +} + +// SendEphemeralPost creates an ephemeral post. +// +// Minimum server version: 5.2 +func (p *PostService) SendEphemeralPost(userID string, post *model.Post) { + *post = *p.api.SendEphemeralPost(userID, post) +} + +// UpdateEphemeralPost updates an ephemeral message previously sent to the user. +// EXPERIMENTAL: This API is experimental and can be changed without advance notice. +// +// Minimum server version: 5.2 +func (p *PostService) UpdateEphemeralPost(userID string, post *model.Post) { + *post = *p.api.UpdateEphemeralPost(userID, post) +} + +// DeleteEphemeralPost deletes an ephemeral message previously sent to the user. +// EXPERIMENTAL: This API is experimental and can be changed without advance notice. +// +// Minimum server version: 5.2 +func (p *PostService) DeleteEphemeralPost(userID, postID string) { + p.api.DeleteEphemeralPost(userID, postID) +} + +// GetPostThread gets a post with all the other posts in the same thread. +// +// Minimum server version: 5.6 +func (p *PostService) GetPostThread(postID string) (*model.PostList, error) { + postList, appErr := p.api.GetPostThread(postID) + + return postList, normalizeAppErr(appErr) +} + +// GetPostsSince gets posts created after a specified time as Unix time in milliseconds. +// +// Minimum server version: 5.6 +func (p *PostService) GetPostsSince(channelID string, time int64) (*model.PostList, error) { + postList, appErr := p.api.GetPostsSince(channelID, time) + + return postList, normalizeAppErr(appErr) +} + +// GetPostsAfter gets a page of posts that were posted after the post provided. +// +// Minimum server version: 5.6 +func (p *PostService) GetPostsAfter(channelID, postID string, page, perPage int) (*model.PostList, error) { + postList, appErr := p.api.GetPostsAfter(channelID, postID, page, perPage) + + return postList, normalizeAppErr(appErr) +} + +// GetPostsBefore gets a page of posts that were posted before the post provided. +// +// Minimum server version: 5.6 +func (p *PostService) GetPostsBefore(channelID, postID string, page, perPage int) (*model.PostList, error) { + postList, appErr := p.api.GetPostsBefore(channelID, postID, page, perPage) + + return postList, normalizeAppErr(appErr) +} + +// GetPostsForChannel gets a list of posts for a channel. +// +// Minimum server version: 5.6 +func (p *PostService) GetPostsForChannel(channelID string, page, perPage int) (*model.PostList, error) { + postList, appErr := p.api.GetPostsForChannel(channelID, page, perPage) + + return postList, normalizeAppErr(appErr) +} + +// SearchPostsInTeam returns a list of posts in a specific team that match the given params. +// +// Minimum server version: 5.10 +func (p *PostService) SearchPostsInTeam(teamID string, paramsList []*model.SearchParams) ([]*model.Post, error) { + postList, appErr := p.api.SearchPostsInTeam(teamID, paramsList) + + return postList, normalizeAppErr(appErr) +} + +// AddReaction add a reaction to a post. +// +// Minimum server version: 5.3 +func (p *PostService) AddReaction(reaction *model.Reaction) error { + addedReaction, appErr := p.api.AddReaction(reaction) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *reaction = *addedReaction + + return nil +} + +// GetReactions get the reactions of a post. +// +// Minimum server version: 5.3 +func (p *PostService) GetReactions(postID string) ([]*model.Reaction, error) { + reactions, appErr := p.api.GetReactions(postID) + + return reactions, normalizeAppErr(appErr) +} + +// RemoveReaction remove a reaction from a post. +// +// Minimum server version: 5.3 +func (p *PostService) RemoveReaction(reaction *model.Reaction) error { + return normalizeAppErr(p.api.RemoveReaction(reaction)) +} + +type ShouldProcessMessageOption func(*shouldProcessMessageOptions) + +type shouldProcessMessageOptions struct { + AllowSystemMessages bool + AllowBots bool + AllowWebhook bool + FilterChannelIDs []string + FilterUserIDs []string + OnlyBotDMs bool + BotID string +} + +// AllowSystemMessages configures a call to ShouldProcessMessage to return true for system messages. +// +// As it is typically desirable only to consume messages from users of the system, ShouldProcessMessage ignores system messages by default. +func AllowSystemMessages() ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.AllowSystemMessages = true + } +} + +// AllowBots configures a call to ShouldProcessMessage to return true for bot posts. +// +// As it is typically desirable only to consume messages from human users of the system, ShouldProcessMessage ignores bot messages by default. +// When allowing bots, take care to avoid a loop where two plugins respond to each others posts repeatedly. +func AllowBots() ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.AllowBots = true + } +} + +// AllowWebhook configures a call to ShouldProcessMessage to return true for posts from webhook. +// +// As it is typically desirable only to consume messages from human users of the system, ShouldProcessMessage ignores webhook messages by default. +func AllowWebhook() ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.AllowWebhook = true + } +} + +// FilterChannelIDs configures a call to ShouldProcessMessage to return true only for the given channels. +// +// By default, posts from all channels are allowed to be processed. +func FilterChannelIDs(filterChannelIDs []string) ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.FilterChannelIDs = filterChannelIDs + } +} + +// FilterUserIDs configures a call to ShouldProcessMessage to return true only for the given users. +// +// By default, posts from all non-bot users are allowed. +func FilterUserIDs(filterUserIDs []string) ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.FilterUserIDs = filterUserIDs + } +} + +// OnlyBotDMs configures a call to ShouldProcessMessage to return true only for direct messages sent to the bot created by EnsureBot. +// +// By default, posts from all channels are allowed. +func OnlyBotDMs() ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.OnlyBotDMs = true + } +} + +// If provided, BotID configures ShouldProcessMessage to skip its retrieval from the store. +// +// By default, posts from all non-bot users are allowed. +func BotID(botID string) ShouldProcessMessageOption { + return func(options *shouldProcessMessageOptions) { + options.BotID = botID + } +} + +// ShouldProcessMessage returns if the message should be processed by a message hook. +// +// Use this method to avoid processing unnecessary messages in a MessageHasBeenPosted +// or MessageWillBePosted hook, and indeed in some cases avoid an infinite loop between +// two automated bots or plugins. +// +// The behavior is customizable using the given options, since plugin needs may vary. +// By default, system messages and messages from bots will be skipped. +// +// Minimum server version: 5.2 +func (p *PostService) ShouldProcessMessage(post *model.Post, options ...ShouldProcessMessageOption) (bool, error) { + messageProcessOptions := &shouldProcessMessageOptions{} + for _, option := range options { + option(messageProcessOptions) + } + + var botIDBytes []byte + var kvGetErr *model.AppError + + if messageProcessOptions.BotID != "" { + botIDBytes = []byte(messageProcessOptions.BotID) + } else { + botIDBytes, kvGetErr = p.api.KVGet(botUserKey) + + if kvGetErr != nil { + return false, errors.Wrap(kvGetErr, "failed to get bot") + } + } + + if botIDBytes != nil { + if post.UserId == string(botIDBytes) { + return false, nil + } + } + + if post.IsSystemMessage() && !messageProcessOptions.AllowSystemMessages { + return false, nil + } + + if !messageProcessOptions.AllowWebhook && post.GetProp("from_webhook") == "true" { + return false, nil + } + + if !messageProcessOptions.AllowBots { + user, appErr := p.api.GetUser(post.UserId) + if appErr != nil { + return false, errors.Wrap(appErr, "unable to get user") + } + + if user.IsBot { + return false, nil + } + } + + if len(messageProcessOptions.FilterChannelIDs) != 0 && !stringInSlice(post.ChannelId, messageProcessOptions.FilterChannelIDs) { + return false, nil + } + + if len(messageProcessOptions.FilterUserIDs) != 0 && !stringInSlice(post.UserId, messageProcessOptions.FilterUserIDs) { + return false, nil + } + + if botIDBytes != nil && messageProcessOptions.OnlyBotDMs { + channel, appErr := p.api.GetChannel(post.ChannelId) + if appErr != nil { + return false, errors.Wrap(appErr, "unable to get channel") + } + + if !model.IsBotDMChannel(channel, string(botIDBytes)) { + return false, nil + } + } + + return true, nil +} diff --git a/server/public/pluginapi/post_test.go b/server/public/pluginapi/post_test.go new file mode 100644 index 0000000000..8c9aa0ff3d --- /dev/null +++ b/server/public/pluginapi/post_test.go @@ -0,0 +1,741 @@ +package pluginapi_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestCreatePost(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + id := model.NewId() + in := &model.Post{ + Id: "postID", + } + out := in.Clone() + out.Id = id + api.On("CreatePost", in).Return(out, nil) + + err := client.Post.CreatePost(in) + require.NoError(t, err) + assert.Equal(t, out, in) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + in := &model.Post{ + Id: "postID", + } + out := in.Clone() + api.On("CreatePost", in).Return(nil, newAppError()) + + err := client.Post.CreatePost(in) + require.EqualError(t, err, "here: id, an error occurred") + assert.Equal(t, out, in) + }) +} + +func TestGetPost(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + expectedPost := &model.Post{ + Id: postID, + } + api.On("GetPost", postID).Return(expectedPost, nil) + + actualPost, err := client.Post.GetPost(postID) + require.NoError(t, err) + assert.Equal(t, expectedPost, actualPost) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + api.On("GetPost", postID).Return(nil, newAppError()) + + actualPost, err := client.Post.GetPost(postID) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPost) + }) +} + +func TestUpdatePost(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + now := model.GetMillis() + in := &model.Post{ + Id: "postID", + } + out := in.Clone() + out.UpdateAt = now + api.On("UpdatePost", in).Return(out, nil) + + err := client.Post.UpdatePost(in) + require.NoError(t, err) + require.Equal(t, out, in) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + in := &model.Post{ + Id: "postID", + } + out := in.Clone() + api.On("UpdatePost", in).Return(nil, newAppError()) + + err := client.Post.UpdatePost(in) + require.EqualError(t, err, "here: id, an error occurred") + assert.Equal(t, out, in) + }) +} + +func TestDeletePost(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + + api.On("DeletePost", postID).Return(nil) + + err := client.Post.DeletePost(postID) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + api.On("DeletePost", postID).Return(newAppError()) + + err := client.Post.DeletePost(postID) + require.EqualError(t, err, "here: id, an error occurred") + }) +} + +func TestSendEphemeralPost(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + userID := "userID" + expectedPost := &model.Post{ + Id: "postID", + } + api.On("SendEphemeralPost", userID, expectedPost).Return(expectedPost, nil) + + client.Post.SendEphemeralPost(userID, expectedPost) +} + +func TestUpdateEphemeralPost(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + now := model.GetMillis() + userID := "userID" + in := &model.Post{ + Id: "postID", + } + out := in.Clone() + out.UpdateAt = now + api.On("UpdateEphemeralPost", userID, in).Return(out, nil) + + client.Post.UpdateEphemeralPost(userID, in) + assert.Equal(t, out, in) +} + +func TestDeleteEphemeralPost(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + userID := "userID" + postID := "postID" + api.On("DeleteEphemeralPost", userID, postID).Return() + + client.Post.DeleteEphemeralPost(userID, postID) +} + +func TestGetPostThread(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + expectedPostList := model.NewPostList() + expectedPostList.AddPost(&model.Post{Id: postID}) + + api.On("GetPostThread", postID).Return(expectedPostList, nil) + + actualPostList, err := client.Post.GetPostThread(postID) + require.NoError(t, err) + assert.Equal(t, expectedPostList, actualPostList) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + api.On("GetPostThread", postID).Return(nil, newAppError()) + + actualPostList, err := client.Post.GetPostThread(postID) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPostList) + }) +} + +func TestGetPostsSince(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + time := int64(0) + expectedPostList := model.NewPostList() + expectedPostList.AddPost(&model.Post{ChannelId: channelID}) + + api.On("GetPostsSince", channelID, time).Return(expectedPostList, nil) + + actualPostList, err := client.Post.GetPostsSince(channelID, time) + require.NoError(t, err) + assert.Equal(t, expectedPostList, actualPostList) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + time := int64(0) + api.On("GetPostsSince", channelID, time).Return(nil, newAppError()) + + actualPostList, err := client.Post.GetPostsSince(channelID, time) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPostList) + }) +} + +func TestGetPostsAfter(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + postID := "postID" + expectedPostList := model.NewPostList() + expectedPostList.AddPost(&model.Post{ChannelId: channelID}) + + api.On("GetPostsAfter", channelID, postID, 0, 0).Return(expectedPostList, nil) + + actualPostList, err := client.Post.GetPostsAfter(channelID, postID, 0, 0) + require.NoError(t, err) + assert.Equal(t, expectedPostList, actualPostList) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + postID := "postID" + api.On("GetPostsAfter", channelID, postID, 0, 0).Return(nil, newAppError()) + + actualPostList, err := client.Post.GetPostsAfter(channelID, postID, 0, 0) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPostList) + }) +} + +func TestGetPostsBefore(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + postID := "postID" + expectedPostList := model.NewPostList() + expectedPostList.AddPost(&model.Post{ChannelId: channelID}) + + api.On("GetPostsBefore", channelID, postID, 0, 0).Return(expectedPostList, nil) + + actualPostList, err := client.Post.GetPostsBefore(channelID, postID, 0, 0) + require.NoError(t, err) + assert.Equal(t, expectedPostList, actualPostList) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + postID := "postID" + api.On("GetPostsBefore", channelID, postID, 0, 0).Return(nil, newAppError()) + + actualPostList, err := client.Post.GetPostsBefore(channelID, postID, 0, 0) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPostList) + }) +} + +func TestGetPostsForChannel(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + expectedPostList := model.NewPostList() + expectedPostList.AddPost(&model.Post{ChannelId: channelID}) + + api.On("GetPostsForChannel", channelID, 0, 0).Return(expectedPostList, nil) + + actualPostList, err := client.Post.GetPostsForChannel(channelID, 0, 0) + require.NoError(t, err) + assert.Equal(t, expectedPostList, actualPostList) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + channelID := "channelID" + api.On("GetPostsForChannel", channelID, 0, 0).Return(nil, newAppError()) + + actualPostList, err := client.Post.GetPostsForChannel(channelID, 0, 0) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPostList) + }) +} + +func TestSearchPostsInTeam(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + teamID := "teamID" + searchParams := []*model.SearchParams{{InChannels: []string{"channelID"}}} + expectedPosts := []*model.Post{{ChannelId: "channelID"}} + api.On("SearchPostsInTeam", teamID, searchParams).Return(expectedPosts, nil) + + actualPostList, err := client.Post.SearchPostsInTeam(teamID, searchParams) + require.NoError(t, err) + assert.Equal(t, expectedPosts, actualPostList) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + teamID := "teamID" + searchParams := []*model.SearchParams{{InChannels: []string{"channelID"}}} + api.On("SearchPostsInTeam", teamID, searchParams).Return(nil, newAppError()) + + actualPostList, err := client.Post.SearchPostsInTeam(teamID, searchParams) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualPostList) + }) +} + +func TestAddReaction(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + in := &model.Reaction{ + PostId: "postId", + } + api.On("AddReaction", in).Return(in, nil) + + err := client.Post.AddReaction(in) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + in := &model.Reaction{ + PostId: "postId", + } + api.On("AddReaction", in).Return(nil, newAppError()) + + err := client.Post.AddReaction(in) + require.EqualError(t, err, "here: id, an error occurred") + }) +} + +func TestGetReactions(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + expectedReactions := []*model.Reaction{ + {PostId: postID, UserId: "user1"}, + {PostId: postID, UserId: "user2"}, + } + api.On("GetReactions", postID).Return(expectedReactions, nil) + + actualReactions, err := client.Post.GetReactions(postID) + require.NoError(t, err) + assert.Equal(t, expectedReactions, actualReactions) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + postID := "postID" + api.On("GetReactions", postID).Return(nil, newAppError()) + + actualReactions, err := client.Post.GetReactions(postID) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualReactions) + }) +} + +func TestDeleteReaction(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + reaction := &model.Reaction{ + PostId: "postId", + } + api.On("RemoveReaction", reaction).Return(nil) + + err := client.Post.RemoveReaction(reaction) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + reaction := &model.Reaction{ + PostId: "postId", + } + api.On("RemoveReaction", reaction).Return(newAppError()) + + err := client.Post.RemoveReaction(reaction) + require.EqualError(t, err, "here: id, an error occurred") + }) +} + +func TestSearchTeamPosts(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("SearchPostsInTeam", "1", []*model.SearchParams{{Terms: "2"}, {Terms: "3"}}). + Return([]*model.Post{{Id: "3"}, {Id: "4"}}, nil) + + posts, err := client.Post.SearchPostsInTeam("1", []*model.SearchParams{{Terms: "2"}, {Terms: "3"}}) + require.NoError(t, err) + require.Equal(t, []*model.Post{{Id: "3"}, {Id: "4"}}, posts) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("SearchPostsInTeam", "1", []*model.SearchParams{{Terms: "2"}, {Terms: "3"}}). + Return(nil, appErr) + + posts, err := client.Post.SearchPostsInTeam("1", []*model.SearchParams{{Terms: "2"}, {Terms: "3"}}) + require.Equal(t, appErr, err) + require.Len(t, posts, 0) + }) +} + +func TestShouldProcessMessage(t *testing.T) { + expectedBotID := model.NewId() + + setupAPI := func() *plugintest.API { + return &plugintest.API{} + } + + t.Run("should not respond to itself", func(t *testing.T) { + api := setupAPI() + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{Type: model.PostTypeHeaderChange, UserId: expectedBotID}, + pluginapi.AllowSystemMessages(), + pluginapi.AllowBots(), + ) + + assert.NoError(t, err) + assert.False(t, shouldProcessMessage) + }) + + t.Run("should not process as the post is generated by system", func(t *testing.T) { + api := setupAPI() + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{Type: model.PostTypeHeaderChange}, + ) + + assert.NoError(t, err) + assert.False(t, shouldProcessMessage) + }) + + t.Run("should not process as the post is sent to another channel", func(t *testing.T) { + channelID := "channel-id" + api := setupAPI() + api.On("GetChannel", channelID).Return(&model.Channel{Id: channelID, Type: model.ChannelTypeGroup}, nil) + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{ChannelId: channelID}, pluginapi.AllowSystemMessages(), + pluginapi.AllowBots(), + pluginapi.FilterChannelIDs([]string{"another-channel-id"}), + ) + + assert.NoError(t, err) + assert.False(t, shouldProcessMessage) + }) + + t.Run("should not process as the post is created by bot", func(t *testing.T) { + userID := "user-id" + channelID := "1" + api := setupAPI() + api.On("GetUser", userID).Return(&model.User{IsBot: true}, nil) + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{UserId: userID, ChannelId: channelID}, + pluginapi.AllowSystemMessages(), + pluginapi.FilterUserIDs([]string{"another-user-id"}), + ) + + assert.NoError(t, err) + assert.False(t, shouldProcessMessage) + }) + + t.Run("should not process the message as the post is not in bot dm channel", func(t *testing.T) { + userID := "user-id" + channelID := "1" + channel := model.Channel{ + Name: "user1__" + expectedBotID, + Type: model.ChannelTypeOpen, + } + api := setupAPI() + api.On("GetChannel", channelID).Return(&channel, nil) + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{UserId: userID, ChannelId: channelID}, + pluginapi.AllowSystemMessages(), + pluginapi.AllowBots(), + pluginapi.OnlyBotDMs(), + ) + + assert.NoError(t, err) + assert.False(t, shouldProcessMessage) + }) + + t.Run("should process the message", func(t *testing.T) { + channelID := "1" + api := setupAPI() + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{UserId: "1", Type: model.PostTypeHeaderChange, ChannelId: channelID}, + pluginapi.AllowSystemMessages(), + pluginapi.FilterChannelIDs([]string{channelID}), + pluginapi.AllowBots(), + pluginapi.FilterUserIDs([]string{"1"}), + ) + + assert.NoError(t, err) + assert.True(t, shouldProcessMessage) + }) + + t.Run("should process the message for plugin without a bot", func(t *testing.T) { + channelID := "1" + api := setupAPI() + api.On("KVGet", plugin.BotUserKey).Return(nil, nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{UserId: "1", Type: model.PostTypeHeaderChange, ChannelId: channelID}, + pluginapi.AllowSystemMessages(), + pluginapi.FilterChannelIDs([]string{channelID}), + pluginapi.AllowBots(), + pluginapi.FilterUserIDs([]string{"1"}), + ) + + assert.NoError(t, err) + assert.True(t, shouldProcessMessage) + }) + + t.Run("should process the message when filter channel and filter users list is empty", func(t *testing.T) { + channelID := "1" + api := setupAPI() + channel := model.Channel{ + Name: "user1__" + expectedBotID, + Type: model.ChannelTypeDirect, + } + api.On("GetChannel", channelID).Return(&channel, nil) + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{UserId: "1", Type: model.PostTypeHeaderChange, ChannelId: channelID}, + pluginapi.AllowSystemMessages(), + pluginapi.AllowBots(), + ) + + assert.NoError(t, err) + assert.True(t, shouldProcessMessage) + }) + + t.Run("should not process the message which have from_webhook", func(t *testing.T) { + channelID := "1" + api := setupAPI() + api.On("GetChannel", channelID).Return(&model.Channel{Id: channelID, Type: model.ChannelTypeGroup}, nil) + + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + + client := pluginapi.NewClient(api, &plugintest.Driver{}) + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{ChannelId: channelID, Props: model.StringInterface{"from_webhook": "true"}}, + pluginapi.AllowBots(), + ) + + assert.NoError(t, err) + assert.False(t, shouldProcessMessage) + }) + + t.Run("should process the message which have from_webhook with allow webhook plugin", func(t *testing.T) { + channelID := "1" + api := setupAPI() + api.On("GetChannel", channelID).Return(&model.Channel{Id: channelID, Type: model.ChannelTypeGroup}, nil) + + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{ChannelId: channelID, Props: model.StringInterface{"from_webhook": "true"}}, + pluginapi.AllowBots(), + pluginapi.AllowWebhook(), + ) + + assert.NoError(t, err) + assert.True(t, shouldProcessMessage) + }) + + t.Run("should process the message where from_webhook is not set", func(t *testing.T) { + channelID := "1" + api := setupAPI() + api.On("GetChannel", channelID).Return(&model.Channel{Id: channelID, Type: model.ChannelTypeGroup}, nil) + + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{ChannelId: channelID}, + pluginapi.AllowBots(), + ) + + assert.NoError(t, err) + assert.True(t, shouldProcessMessage) + }) + + t.Run("should process the message which have from_webhook false", func(t *testing.T) { + channelID := "1" + api := setupAPI() + api.On("GetChannel", channelID).Return(&model.Channel{Id: channelID, Type: model.ChannelTypeGroup}, nil) + + api.On("KVGet", plugin.BotUserKey).Return([]byte(expectedBotID), nil) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{ChannelId: channelID, Props: model.StringInterface{"from_webhook": "false"}}, + pluginapi.AllowBots(), + ) + + assert.NoError(t, err) + assert.True(t, shouldProcessMessage) + }) + + t.Run("should process the message when we pass the botId as input", func(t *testing.T) { + userID := "user-id" + channelID := "1" + api := setupAPI() + api.On("GetChannel", channelID).Return(&model.Channel{Id: channelID, Type: model.ChannelTypeGroup}, nil) + + api.On("GetUser", userID).Return(&model.User{IsBot: false}, nil) + + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + shouldProcessMessage, err := client.Post.ShouldProcessMessage( + &model.Post{ChannelId: channelID, UserId: userID}, + pluginapi.BotID(expectedBotID), + ) + assert.NoError(t, err) + + assert.True(t, shouldProcessMessage) + }) +} diff --git a/server/public/pluginapi/session.go b/server/public/pluginapi/session.go new file mode 100644 index 0000000000..05c11408c8 --- /dev/null +++ b/server/public/pluginapi/session.go @@ -0,0 +1,43 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// SessionService exposes methods to manipulate groups. +type SessionService struct { + api plugin.API +} + +// Get returns the session object for the Session ID +// +// Minimum server version: 5.2 +func (s *SessionService) Get(id string) (*model.Session, error) { + session, appErr := s.api.GetSession(id) + + return session, normalizeAppErr(appErr) +} + +// Create creates a new user session. +// +// Minimum server version: 6.2 +func (s *SessionService) Create(session *model.Session) (*model.Session, error) { + session, appErr := s.api.CreateSession(session) + + return session, normalizeAppErr(appErr) +} + +// ExtendSessionExpiry extends the duration of an existing session. +// +// Minimum server version: 6.2 +func (s *SessionService) ExtendExpiry(sessionID string, newExpiry int64) error { + return normalizeAppErr(s.api.ExtendSessionExpiry(sessionID, newExpiry)) +} + +// RevokeSession revokes an existing user session. +// +// Minimum server version: 6.2 +func (s *SessionService) Revoke(sessionID string) error { + return normalizeAppErr(s.api.RevokeSession(sessionID)) +} diff --git a/server/public/pluginapi/slashcommand.go b/server/public/pluginapi/slashcommand.go new file mode 100644 index 0000000000..7fe939e543 --- /dev/null +++ b/server/public/pluginapi/slashcommand.go @@ -0,0 +1,100 @@ +package pluginapi + +import ( + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// SlashCommandService exposes methods to manipulate slash commands. +type SlashCommandService struct { + api plugin.API +} + +// Register registers a custom slash command. When the command is triggered, your plugin +// can fulfill it via the ExecuteCommand hook. +// +// Minimum server version: 5.2 +func (c *SlashCommandService) Register(command *model.Command) error { + return c.api.RegisterCommand(command) +} + +// Unregister unregisters a command previously registered via Register. +// +// Minimum server version: 5.2 +func (c *SlashCommandService) Unregister(teamID, trigger string) error { + return c.api.UnregisterCommand(teamID, trigger) +} + +// Execute executes a slash command. +// +// Minimum server version: 5.26 +func (c *SlashCommandService) Execute(command *model.CommandArgs) (*model.CommandResponse, error) { + return c.api.ExecuteSlashCommand(command) +} + +// Create creates a server-owned slash command that is not handled by the plugin +// itself, and which will persist past the life of the plugin. The command will have its +// CreatorId set to "" and its PluginId set to the id of the plugin that created it. +// +// Minimum server version: 5.28 +func (c *SlashCommandService) Create(command *model.Command) (*model.Command, error) { + return c.api.CreateCommand(command) +} + +// List returns the list of all slash commands for teamID. E.g., custom commands +// (those created through the integrations menu, the REST api, or the plugin api CreateCommand), +// plugin commands (those created with plugin api RegisterCommand), and builtin commands +// (those added internally through RegisterCommandProvider). +// +// Minimum server version: 5.28 +func (c *SlashCommandService) List(teamID string) ([]*model.Command, error) { + return c.api.ListCommands(teamID) +} + +// ListCustom returns the list of slash commands for teamID that where created +// through the integrations menu, the REST api, or the plugin api CreateCommand. +// +// Minimum server version: 5.28 +func (c *SlashCommandService) ListCustom(teamID string) ([]*model.Command, error) { + return c.api.ListCustomCommands(teamID) +} + +// ListPlugin returns the list of slash commands for teamID that were created +// with the plugin api RegisterCommand. +// +// Minimum server version: 5.28 +func (c *SlashCommandService) ListPlugin(teamID string) ([]*model.Command, error) { + return c.api.ListPluginCommands(teamID) +} + +// ListBuiltIn returns the list of slash commands that are builtin commands +// (those added internally through RegisterCommandProvider). +// +// Minimum server version: 5.28 +func (c *SlashCommandService) ListBuiltIn() ([]*model.Command, error) { + return c.api.ListBuiltInCommands() +} + +// Get returns the command definition based on a command id string. +// +// Minimum server version: 5.28 +func (c *SlashCommandService) Get(commandID string) (*model.Command, error) { + return c.api.GetCommand(commandID) +} + +// Update updates a single command (identified by commandID) with the information provided in the +// updatedCmd model.Command struct. The following fields in the command cannot be updated: +// Id, Token, CreateAt, DeleteAt, and PluginId. If updatedCmd.TeamId is blank, it +// will be set to commandID's TeamId. +// +// Minimum server version: 5.28 +func (c *SlashCommandService) Update(commandID string, updatedCmd *model.Command) (*model.Command, error) { + return c.api.UpdateCommand(commandID, updatedCmd) +} + +// Delete deletes a slash command (identified by commandID). +// +// Minimum server version: 5.28 +func (c *SlashCommandService) Delete(commandID string) error { + return c.api.DeleteCommand(commandID) +} diff --git a/server/public/pluginapi/store.go b/server/public/pluginapi/store.go new file mode 100644 index 0000000000..c0daee7c80 --- /dev/null +++ b/server/public/pluginapi/store.go @@ -0,0 +1,117 @@ +package pluginapi + +import ( + "database/sql" + "sync" + + // import sql drivers + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + + "github.com/mattermost/mattermost/server/public/plugin" + "github.com/mattermost/mattermost/server/public/shared/driver" + "github.com/pkg/errors" +) + +// StoreService exposes the underlying database. +type StoreService struct { + initialized bool + api plugin.API + driver plugin.Driver + mutex sync.Mutex + + masterDB *sql.DB + replicaDB *sql.DB +} + +// GetMasterDB gets the master database handle. +// +// Minimum server version: 5.16 +func (s *StoreService) GetMasterDB() (*sql.DB, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if err := s.initialize(); err != nil { + return nil, err + } + + return s.masterDB, nil +} + +// GetReplicaDB gets the replica database handle. +// Returns masterDB if a replica is not configured. +// +// Minimum server version: 5.16 +func (s *StoreService) GetReplicaDB() (*sql.DB, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if err := s.initialize(); err != nil { + return nil, err + } + + if s.replicaDB != nil { + return s.replicaDB, nil + } + + return s.masterDB, nil +} + +// Close closes any open resources. This method is idempotent. +func (s *StoreService) Close() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if !s.initialized { + return nil + } + + if err := s.masterDB.Close(); err != nil { + return err + } + + if s.replicaDB != nil { + if err := s.replicaDB.Close(); err != nil { + return err + } + } + + return nil +} + +// DriverName returns the driver name for the datasource. +func (s *StoreService) DriverName() string { + return *s.api.GetConfig().SqlSettings.DriverName +} + +func (s *StoreService) initialize() error { + if s.initialized { + return nil + } + + if s.driver == nil { + return errors.New("no db driver was provided") + } + + config := s.api.GetUnsanitizedConfig() + + // Set up master db + db := sql.OpenDB(driver.NewConnector(s.driver, true)) + if err := db.Ping(); err != nil { + return errors.Wrap(err, "failed to connect to master db") + } + s.masterDB = db + + // Set up replica db + if len(config.SqlSettings.DataSourceReplicas) > 0 { + db := sql.OpenDB(driver.NewConnector(s.driver, false)) + if err := db.Ping(); err != nil { + return errors.Wrap(err, "failed to connect to replica db") + } + s.replicaDB = db + } + + s.initialized = true + + return nil +} diff --git a/server/public/pluginapi/store_test.go b/server/public/pluginapi/store_test.go new file mode 100644 index 0000000000..78b2b98f47 --- /dev/null +++ b/server/public/pluginapi/store_test.go @@ -0,0 +1,109 @@ +package pluginapi_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestStore(t *testing.T) { + t.Run("master db singleton", func(t *testing.T) { + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DriverName: model.NewString("test"), + DataSource: model.NewString("TestStore-master-db"), + }, + } + + api := &plugintest.API{} + defer api.AssertExpectations(t) + api.On("GetUnsanitizedConfig").Return(config) + + driver := &plugintest.Driver{} + driver.On("Conn", true).Return("test", nil) + driver.On("ConnPing", "test").Return(nil) + driver.On("ConnClose", "test").Return(nil) + + store := pluginapi.NewClient(api, driver).Store + + db1, err := store.GetMasterDB() + require.NoError(t, err) + require.NotNil(t, db1) + + db2, err := store.GetMasterDB() + require.NoError(t, err) + require.NotNil(t, db2) + + require.Same(t, db1, db2) + require.NoError(t, store.Close()) + }) + + t.Run("master db fallback", func(t *testing.T) { + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DriverName: model.NewString("ramsql"), + DataSource: model.NewString("TestStore-master-db"), + ConnMaxLifetimeMilliseconds: model.NewInt(2), + }, + } + + driver := &plugintest.Driver{} + driver.On("Conn", true).Return("test", nil) + driver.On("ConnPing", "test").Return(nil) + driver.On("ConnClose", "test").Return(nil) + + api := &plugintest.API{} + defer api.AssertExpectations(t) + store := pluginapi.NewClient(api, driver).Store + + api.On("GetUnsanitizedConfig").Return(config) + masterDB, err := store.GetMasterDB() + require.NoError(t, err) + require.NotNil(t, masterDB) + + // No replica is set up, should fallback to master + replicaDB, err := store.GetReplicaDB() + require.NoError(t, err) + require.Same(t, replicaDB, masterDB) + + require.NoError(t, store.Close()) + }) + + t.Run("replica db singleton", func(t *testing.T) { + config := &model.Config{ + SqlSettings: model.SqlSettings{ + DriverName: model.NewString("ramsql"), + DataSource: model.NewString("TestStore-master-db"), + DataSourceReplicas: []string{"TestStore-master-db"}, + ConnMaxLifetimeMilliseconds: model.NewInt(2), + }, + } + + api := &plugintest.API{} + defer api.AssertExpectations(t) + api.On("GetUnsanitizedConfig").Return(config) + + driver := &plugintest.Driver{} + driver.On("Conn", true).Return("test", nil) + driver.On("Conn", false).Return("test", nil) + driver.On("ConnPing", "test").Return(nil) + driver.On("ConnClose", "test").Return(nil) + + store := pluginapi.NewClient(api, driver).Store + + db1, err := store.GetReplicaDB() + require.NoError(t, err) + require.NotNil(t, db1) + + db2, err := store.GetReplicaDB() + require.NoError(t, err) + require.NotNil(t, db2) + + require.Same(t, db1, db2) + require.NoError(t, store.Close()) + }) +} diff --git a/server/public/pluginapi/system.go b/server/public/pluginapi/system.go new file mode 100644 index 0000000000..e548d5d8a8 --- /dev/null +++ b/server/public/pluginapi/system.go @@ -0,0 +1,124 @@ +package pluginapi + +import ( + "net/url" + "path" + "time" + + "github.com/blang/semver/v4" + "github.com/pkg/errors" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// SystemService exposes methods to query system properties. +type SystemService struct { + api plugin.API +} + +// GetManifest returns the manifest from the plugin bundle. +// +// Minimum server version: 5.10 +func (s *SystemService) GetManifest() (*model.Manifest, error) { + p, err := s.api.GetBundlePath() + if err != nil { + return nil, err + } + + m, _, err := model.FindManifest(p) + if err != nil { + return nil, errors.Wrap(err, "failed to find and open manifest") + } + + return m, nil +} + +// GetBundlePath returns the absolute path where the plugin's bundle was unpacked. +// +// Minimum server version: 5.10 +func (s *SystemService) GetBundlePath() (string, error) { + return s.api.GetBundlePath() +} + +// GetPluginAssetURL builds a URL to the given asset in the assets directory. +// Use this URL to link to assets from the webapp, or for third-party integrations with your plugin. +// +// Minimum server version: 5.2 +func (s *SystemService) GetPluginAssetURL(pluginID, asset string) (string, error) { + if pluginID == "" { + return "", errors.New("empty pluginID provided") + } + + if asset == "" { + return "", errors.New("empty asset name provided") + } + + siteURL := *s.api.GetConfig().ServiceSettings.SiteURL + if siteURL == "" { + return "", errors.New("no SiteURL configured by the server") + } + + u, err := url.Parse(siteURL + path.Join("/", pluginID, asset)) + if err != nil { + return "", err + } + + return u.String(), nil +} + +// GetLicense returns the current license used by the Mattermost server. Returns nil if the +// the server does not have a license. +// +// Minimum server version: 5.10 +func (s *SystemService) GetLicense() *model.License { + return s.api.GetLicense() +} + +// GetServerVersion return the current Mattermost server version +// +// Minimum server version: 5.4 +func (s *SystemService) GetServerVersion() string { + return s.api.GetServerVersion() +} + +// IsEnterpriseReady returns true if the Mattermost server is configured as Enterprise Ready. +// +// Minimum server version: 6.1 +func (s *SystemService) IsEnterpriseReady() bool { + return s.api.IsEnterpriseReady() +} + +// GetSystemInstallDate returns the time that Mattermost was first installed and ran. +// +// Minimum server version: 5.10 +func (s *SystemService) GetSystemInstallDate() (time.Time, error) { + installDateMS, appErr := s.api.GetSystemInstallDate() + installDate := time.Unix(0, installDateMS*int64(time.Millisecond)) + + return installDate, normalizeAppErr(appErr) +} + +// GetDiagnosticID returns a unique identifier used by the server for diagnostic reports. +// +// Minimum server version: 5.10 +func (s *SystemService) GetDiagnosticID() string { + // TODO: Consider deprecating/rewriting in favor of just using GetUnsanitizedConfig(). + return s.api.GetDiagnosticId() +} + +// RequestTrialLicense requests a trial license and installs it in the server. +// If the server version is lower than 5.36.0, an error is returned. +// +// Minimum server version: 5.36 +func (s *SystemService) RequestTrialLicense(requesterID string, users int, termsAccepted, receiveEmailsAccepted bool) error { + currentVersion := semver.MustParse(s.api.GetServerVersion()) + requiredVersion := semver.MustParse("5.36.0") + + if currentVersion.LT(requiredVersion) { + return errors.Errorf("current server version is lower than 5.36") + } + + err := s.api.RequestTrialLicense(requesterID, users, termsAccepted, receiveEmailsAccepted) + return normalizeAppErr(err) +} diff --git a/server/public/pluginapi/system_test.go b/server/public/pluginapi/system_test.go new file mode 100644 index 0000000000..5597365f71 --- /dev/null +++ b/server/public/pluginapi/system_test.go @@ -0,0 +1,104 @@ +package pluginapi_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestGetManifest(t *testing.T) { + t.Run("valid manifest", func(t *testing.T) { + content := []byte(` + { + "id": "some.id", + "name": "Some Name" + } + `) + expectedManifest := &model.Manifest{ + Id: "some.id", + Name: "Some Name", + } + + dir, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(dir) + + tmpfn := filepath.Join(dir, "plugin.json") + //nolint:gosec //only used in tests + err = os.WriteFile(tmpfn, content, 0o666) + require.NoError(t, err) + + api := &plugintest.API{} + api.On("GetBundlePath").Return(dir, nil) + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + m, err := client.System.GetManifest() + require.NoError(t, err) + require.Equal(t, expectedManifest, m) + + // Altering the pointer doesn't alter the result + m.Id = "new.id" + + m2, err := client.System.GetManifest() + require.NoError(t, err) + require.Equal(t, expectedManifest, m2) + }) + + t.Run("GetBundlePath fails", func(t *testing.T) { + api := &plugintest.API{} + api.On("GetBundlePath").Return("", errors.New("")) + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + m, err := client.System.GetManifest() + require.Error(t, err) + require.Nil(t, m) + }) + + t.Run("No manifest found", func(t *testing.T) { + dir, err := os.MkdirTemp("", "") + require.NoError(t, err) + defer os.RemoveAll(dir) + + api := &plugintest.API{} + api.On("GetBundlePath").Return(dir, nil) + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + m, err := client.System.GetManifest() + require.Error(t, err) + require.Nil(t, m) + }) +} + +func TestRequestTrialLicense(t *testing.T) { + t.Run("Server version incompatible", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetServerVersion").Return("5.35.0") + err := client.System.RequestTrialLicense("requesterID", 10, true, true) + + require.Error(t, err) + require.Equal(t, "current server version is lower than 5.36", err.Error()) + }) + + t.Run("Server version compatible", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetServerVersion").Return("5.36.0") + api.On("RequestTrialLicense", "requesterID", 10, true, true).Return(nil) + + err := client.System.RequestTrialLicense("requesterID", 10, true, true) + + require.NoError(t, err) + }) +} diff --git a/server/public/pluginapi/team.go b/server/public/pluginapi/team.go new file mode 100644 index 0000000000..b3fb9f05cf --- /dev/null +++ b/server/public/pluginapi/team.go @@ -0,0 +1,231 @@ +package pluginapi + +import ( + "bytes" + "io" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// TeamService exposes methods to manipulate teams and their members. +type TeamService struct { + api plugin.API +} + +// Get gets a team. +// +// Minimum server version: 5.2 +func (t *TeamService) Get(teamID string) (*model.Team, error) { + team, appErr := t.api.GetTeam(teamID) + + return team, normalizeAppErr(appErr) +} + +// GetByName gets a team by its name. +// +// Minimum server version: 5.2 +func (t *TeamService) GetByName(name string) (*model.Team, error) { + team, appErr := t.api.GetTeamByName(name) + + return team, normalizeAppErr(appErr) +} + +// TeamListOption is used to filter team listing. +type TeamListOption func(*ListTeamsOptions) + +// ListTeamsOptions holds options about filter out team listing. +type ListTeamsOptions struct { + UserID string +} + +// FilterTeamsByUser option is used to filter teams by user. +func FilterTeamsByUser(userID string) TeamListOption { + return func(o *ListTeamsOptions) { + o.UserID = userID + } +} + +// List gets a list of teams by options. +// +// Minimum server version: 5.2 +// Minimum server version when LimitTeamsToUser() option is used: 5.6 +func (t *TeamService) List(options ...TeamListOption) ([]*model.Team, error) { + opts := ListTeamsOptions{} + for _, o := range options { + o(&opts) + } + + var teams []*model.Team + var appErr *model.AppError + if opts.UserID != "" { + teams, appErr = t.api.GetTeamsForUser(opts.UserID) + } else { + teams, appErr = t.api.GetTeams() + } + + return teams, normalizeAppErr(appErr) +} + +// Search search a team. +// +// Minimum server version: 5.8 +func (t *TeamService) Search(term string) ([]*model.Team, error) { + teams, appErr := t.api.SearchTeams(term) + + return teams, normalizeAppErr(appErr) +} + +// Create creates a team. +// +// Minimum server version: 5.2 +func (t *TeamService) Create(team *model.Team) error { + createdTeam, appErr := t.api.CreateTeam(team) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *team = *createdTeam + + return nil +} + +// Update updates a team. +// +// Minimum server version: 5.2 +func (t *TeamService) Update(team *model.Team) error { + updatedTeam, appErr := t.api.UpdateTeam(team) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *team = *updatedTeam + + return nil +} + +// Delete deletes a team. +// +// Minimum server version: 5.2 +func (t *TeamService) Delete(teamID string) error { + return normalizeAppErr(t.api.DeleteTeam(teamID)) +} + +// GetIcon gets the team icon. +// +// Minimum server version: 5.6 +func (t *TeamService) GetIcon(teamID string) (io.Reader, error) { + contentBytes, appErr := t.api.GetTeamIcon(teamID) + if appErr != nil { + return nil, normalizeAppErr(appErr) + } + + return bytes.NewReader(contentBytes), nil +} + +// SetIcon sets the team icon. +// +// Minimum server version: 5.6 +func (t *TeamService) SetIcon(teamID string, content io.Reader) error { + contentBytes, err := io.ReadAll(content) + if err != nil { + return err + } + + return normalizeAppErr(t.api.SetTeamIcon(teamID, contentBytes)) +} + +// DeleteIcon removes the team icon. +// +// Minimum server version: 5.6 +func (t *TeamService) DeleteIcon(teamID string) error { + return normalizeAppErr(t.api.RemoveTeamIcon(teamID)) +} + +// GetUsers lists users of the team. +// +// Minimum server version: 5.6 +func (t *TeamService) ListUsers(teamID string, page, count int) ([]*model.User, error) { + users, appErr := t.api.GetUsersInTeam(teamID, page, count) + + return users, normalizeAppErr(appErr) +} + +// ListUnreadForUser gets the unread message and mention counts for each team to which the given user belongs. +// +// Minimum server version: 5.6 +func (t *TeamService) ListUnreadForUser(userID string) ([]*model.TeamUnread, error) { + teamUnreads, appErr := t.api.GetTeamsUnreadForUser(userID) + + return teamUnreads, normalizeAppErr(appErr) +} + +// GetMember returns a specific membership. +// +// Minimum server version: 5.2 +func (t *TeamService) GetMember(teamID, userID string) (*model.TeamMember, error) { + teamMember, appErr := t.api.GetTeamMember(teamID, userID) + + return teamMember, normalizeAppErr(appErr) +} + +// ListMembers returns the memberships of a specific team. +// +// Minimum server version: 5.2 +func (t *TeamService) ListMembers(teamID string, page, perPage int) ([]*model.TeamMember, error) { + teamMembers, appErr := t.api.GetTeamMembers(teamID, page, perPage) + + return teamMembers, normalizeAppErr(appErr) +} + +// ListMembersForUser returns all team memberships for a user. +// +// Minimum server version: 5.10 +func (t *TeamService) ListMembersForUser(userID string, page, perPage int) ([]*model.TeamMember, error) { + teamMembers, appErr := t.api.GetTeamMembersForUser(userID, page, perPage) + + return teamMembers, normalizeAppErr(appErr) +} + +// CreateMember creates a team membership. +// +// Minimum server version: 5.2 +func (t *TeamService) CreateMember(teamID, userID string) (*model.TeamMember, error) { + teamMember, appErr := t.api.CreateTeamMember(teamID, userID) + + return teamMember, normalizeAppErr(appErr) +} + +// CreateMembers creates a team membership for all provided user ids. +// +// Minimum server version: 5.2 +func (t *TeamService) CreateMembers(teamID string, userIDs []string, requestorID string) ([]*model.TeamMember, error) { + teamMembers, appErr := t.api.CreateTeamMembers(teamID, userIDs, requestorID) + + return teamMembers, normalizeAppErr(appErr) +} + +// DeleteMember deletes a team membership. +// +// Minimum server version: 5.2 +func (t *TeamService) DeleteMember(teamID, userID, requestorID string) error { + return normalizeAppErr(t.api.DeleteTeamMember(teamID, userID, requestorID)) +} + +// UpdateMemberRoles updates the role for a team membership. +// +// Minimum server version: 5.2 +func (t *TeamService) UpdateMemberRoles(teamID, userID, newRoles string) (*model.TeamMember, error) { + teamMember, appErr := t.api.UpdateTeamMemberRoles(teamID, userID, newRoles) + + return teamMember, normalizeAppErr(appErr) +} + +// GetStats gets a team's statistics +// +// Minimum server version: 5.8 +func (t *TeamService) GetStats(teamID string) (*model.TeamStats, error) { + teamStats, appErr := t.api.GetTeamStats(teamID) + + return teamStats, normalizeAppErr(appErr) +} diff --git a/server/public/pluginapi/team_test.go b/server/public/pluginapi/team_test.go new file mode 100644 index 0000000000..4dfe6d79ff --- /dev/null +++ b/server/public/pluginapi/team_test.go @@ -0,0 +1,587 @@ +package pluginapi_test + +import ( + "bytes" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestCreateTeam(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("CreateTeam", &model.Team{Name: "1"}).Return(&model.Team{Name: "1", Id: "2"}, nil) + + team := &model.Team{Name: "1"} + err := client.Team.Create(team) + require.NoError(t, err) + require.Equal(t, &model.Team{Name: "1", Id: "2"}, team) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("CreateTeam", &model.Team{Name: "1"}).Return(nil, appErr) + + team := &model.Team{Name: "1"} + err := client.Team.Create(team) + require.Equal(t, appErr, err) + require.Equal(t, &model.Team{Name: "1"}, team) + }) +} + +func TestGetTeam(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeam", "1").Return(&model.Team{Id: "2"}, nil) + + team, err := client.Team.Get("1") + require.NoError(t, err) + require.Equal(t, &model.Team{Id: "2"}, team) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeam", "1").Return(nil, appErr) + + team, err := client.Team.Get("1") + require.Equal(t, appErr, err) + require.Zero(t, team) + }) +} + +func TestGetTeamByName(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamByName", "1").Return(&model.Team{Id: "2"}, nil) + + team, err := client.Team.GetByName("1") + require.NoError(t, err) + require.Equal(t, &model.Team{Id: "2"}, team) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamByName", "1").Return(nil, appErr) + + team, err := client.Team.GetByName("1") + require.Equal(t, appErr, err) + require.Zero(t, team) + }) +} + +func TestUpdateTeam(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("UpdateTeam", &model.Team{Name: "1"}).Return(&model.Team{Name: "1", Id: "2"}, nil) + + team := &model.Team{Name: "1"} + err := client.Team.Update(team) + require.NoError(t, err) + require.Equal(t, &model.Team{Name: "1", Id: "2"}, team) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("UpdateTeam", &model.Team{Name: "1"}).Return(nil, appErr) + + team := &model.Team{Name: "1"} + err := client.Team.Update(team) + require.Equal(t, appErr, err) + require.Equal(t, &model.Team{Name: "1"}, team) + }) +} + +func TestListTeams(t *testing.T) { + t.Run("list all", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeams").Return([]*model.Team{{Id: "1"}, {Id: "2"}}, nil) + + teams, err := client.Team.List() + require.NoError(t, err) + require.Equal(t, []*model.Team{{Id: "1"}, {Id: "2"}}, teams) + }) + + t.Run("list scoped to user", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamsForUser", "3").Return([]*model.Team{{Id: "1"}, {Id: "2"}}, nil) + + teams, err := client.Team.List(pluginapi.FilterTeamsByUser("3")) + require.NoError(t, err) + require.Equal(t, []*model.Team{{Id: "1"}, {Id: "2"}}, teams) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeams").Return(nil, appErr) + + teams, err := client.Team.List() + require.Equal(t, appErr, err) + require.Len(t, teams, 0) + }) +} + +func TestSearchTeams(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("SearchTeams", "1").Return([]*model.Team{{Id: "1"}, {Id: "2"}}, nil) + + teams, err := client.Team.Search("1") + require.NoError(t, err) + require.Equal(t, []*model.Team{{Id: "1"}, {Id: "2"}}, teams) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("SearchTeams", "1").Return(nil, appErr) + + teams, err := client.Team.Search("1") + require.Equal(t, appErr, err) + require.Zero(t, teams) + }) +} + +func TestDeleteTeam(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("DeleteTeam", "1").Return(nil) + + err := client.Team.Delete("1") + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("DeleteTeam", "1").Return(appErr) + + err := client.Team.Delete("1") + require.Equal(t, appErr, err) + }) +} + +func TestGetTeamIcon(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamIcon", "1").Return([]byte{2}, nil) + + content, err := client.Team.GetIcon("1") + require.NoError(t, err) + contentBytes, err := io.ReadAll(content) + require.NoError(t, err) + require.Equal(t, []byte{2}, contentBytes) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamIcon", "1").Return(nil, appErr) + + content, err := client.Team.GetIcon("1") + require.Equal(t, appErr, err) + require.Zero(t, content) + }) +} + +func TestSetIcon(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("SetTeamIcon", "1", []byte{2}).Return(nil) + + err := client.Team.SetIcon("1", bytes.NewReader([]byte{2})) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("SetTeamIcon", "1", []byte{2}).Return(appErr) + + err := client.Team.SetIcon("1", bytes.NewReader([]byte{2})) + require.Equal(t, appErr, err) + }) +} + +func TestDeleteIcon(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("RemoveTeamIcon", "1").Return(nil) + + err := client.Team.DeleteIcon("1") + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("RemoveTeamIcon", "1").Return(appErr) + + err := client.Team.DeleteIcon("1") + require.Equal(t, appErr, err) + }) +} + +func TestGetTeamUsers(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetUsersInTeam", "1", 2, 3).Return([]*model.User{{Id: "1"}, {Id: "2"}}, nil) + + users, err := client.Team.ListUsers("1", 2, 3) + require.NoError(t, err) + require.Equal(t, []*model.User{{Id: "1"}, {Id: "2"}}, users) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetUsersInTeam", "1", 2, 3).Return(nil, appErr) + + users, err := client.Team.ListUsers("1", 2, 3) + require.Equal(t, appErr, err) + require.Len(t, users, 0) + }) +} + +func TestGetTeamUnreads(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamsUnreadForUser", "1").Return([]*model.TeamUnread{{TeamId: "1"}, {TeamId: "2"}}, nil) + + unreads, err := client.Team.ListUnreadForUser("1") + require.NoError(t, err) + require.Equal(t, []*model.TeamUnread{{TeamId: "1"}, {TeamId: "2"}}, unreads) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamsUnreadForUser", "1").Return(nil, appErr) + + unreads, err := client.Team.ListUnreadForUser("1") + require.Equal(t, appErr, err) + require.Len(t, unreads, 0) + }) +} + +func TestCreateTeamMember(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("CreateTeamMember", "1", "2").Return(&model.TeamMember{TeamId: "3"}, nil) + + member, err := client.Team.CreateMember("1", "2") + require.NoError(t, err) + require.Equal(t, &model.TeamMember{TeamId: "3"}, member) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("CreateTeamMember", "1", "2").Return(nil, appErr) + + member, err := client.Team.CreateMember("1", "2") + require.Equal(t, appErr, err) + require.Zero(t, member) + }) +} + +func TestCreateTeamMembers(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("CreateTeamMembers", "1", []string{"2"}, "3").Return([]*model.TeamMember{{TeamId: "4"}, {TeamId: "5"}}, nil) + + members, err := client.Team.CreateMembers("1", []string{"2"}, "3") + require.NoError(t, err) + require.Equal(t, []*model.TeamMember{{TeamId: "4"}, {TeamId: "5"}}, members) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("CreateTeamMembers", "1", []string{"2"}, "3").Return(nil, appErr) + + members, err := client.Team.CreateMembers("1", []string{"2"}, "3") + require.Equal(t, appErr, err) + require.Len(t, members, 0) + }) +} + +func TestGetTeamMember(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamMember", "1", "2").Return(&model.TeamMember{TeamId: "3"}, nil) + + member, err := client.Team.GetMember("1", "2") + require.NoError(t, err) + require.Equal(t, &model.TeamMember{TeamId: "3"}, member) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamMember", "1", "2").Return(nil, appErr) + + member, err := client.Team.GetMember("1", "2") + require.Equal(t, appErr, err) + require.Zero(t, member) + }) +} + +func TestGetTeamMembers(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamMembers", "1", 2, 3).Return([]*model.TeamMember{{TeamId: "4"}, {TeamId: "5"}}, nil) + + members, err := client.Team.ListMembers("1", 2, 3) + require.NoError(t, err) + require.Equal(t, []*model.TeamMember{{TeamId: "4"}, {TeamId: "5"}}, members) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamMembers", "1", 2, 3).Return(nil, appErr) + + members, err := client.Team.ListMembers("1", 2, 3) + require.Equal(t, appErr, err) + require.Len(t, members, 0) + }) +} + +func TestGetUserMemberships(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamMembersForUser", "1", 2, 3).Return([]*model.TeamMember{{TeamId: "4"}, {TeamId: "5"}}, nil) + + members, err := client.Team.ListMembersForUser("1", 2, 3) + require.NoError(t, err) + require.Equal(t, []*model.TeamMember{{TeamId: "4"}, {TeamId: "5"}}, members) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamMembersForUser", "1", 2, 3).Return(nil, appErr) + + members, err := client.Team.ListMembersForUser("1", 2, 3) + require.Equal(t, appErr, err) + require.Len(t, members, 0) + }) +} + +func TestUpdateTeamMemberRoles(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("UpdateTeamMemberRoles", "1", "2", "3").Return(&model.TeamMember{TeamId: "3"}, nil) + + membership, err := client.Team.UpdateMemberRoles("1", "2", "3") + require.NoError(t, err) + require.Equal(t, &model.TeamMember{TeamId: "3"}, membership) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("UpdateTeamMemberRoles", "1", "2", "3").Return(nil, appErr) + + membership, err := client.Team.UpdateMemberRoles("1", "2", "3") + require.Equal(t, appErr, err) + require.Zero(t, membership) + }) +} + +func TestDeleteTeamMember(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("DeleteTeamMember", "1", "2", "3").Return(nil) + + err := client.Team.DeleteMember("1", "2", "3") + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("DeleteTeamMember", "1", "2", "3").Return(appErr) + + err := client.Team.DeleteMember("1", "2", "3") + require.Equal(t, appErr, err) + }) +} + +func TestGetTeamStats(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("GetTeamStats", "1").Return(&model.TeamStats{TeamId: "3"}, nil) + + stats, err := client.Team.GetStats("1") + require.NoError(t, err) + require.Equal(t, &model.TeamStats{TeamId: "3"}, stats) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + appErr := model.NewAppError("here", "id", nil, "an error occurred", http.StatusInternalServerError) + + api.On("GetTeamStats", "1").Return(nil, appErr) + + stats, err := client.Team.GetStats("1") + require.Equal(t, appErr, err) + require.Zero(t, stats) + }) +} diff --git a/server/public/pluginapi/user.go b/server/public/pluginapi/user.go new file mode 100644 index 0000000000..bb6252dd6a --- /dev/null +++ b/server/public/pluginapi/user.go @@ -0,0 +1,246 @@ +package pluginapi + +import ( + "bytes" + "io" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" +) + +// UserService exposes methods to manipulate users. +type UserService struct { + api plugin.API +} + +// Get gets a user. +// +// Minimum server version: 5.2 +func (u *UserService) Get(userID string) (*model.User, error) { + user, appErr := u.api.GetUser(userID) + + return user, normalizeAppErr(appErr) +} + +// GetByEmail gets a user by their email address. +// +// Minimum server version: 5.2 +func (u *UserService) GetByEmail(email string) (*model.User, error) { + user, appErr := u.api.GetUserByEmail(email) + + return user, normalizeAppErr(appErr) +} + +// GetByUsername gets a user by their username. +// +// Minimum server version: 5.2 +func (u *UserService) GetByUsername(username string) (*model.User, error) { + user, appErr := u.api.GetUserByUsername(username) + + return user, normalizeAppErr(appErr) +} + +// List a list of users based on search options. +// +// Minimum server version: 5.10 +func (u *UserService) List(options *model.UserGetOptions) ([]*model.User, error) { + users, appErr := u.api.GetUsers(options) + + return users, normalizeAppErr(appErr) +} + +// ListByUsernames gets users by their usernames. +// +// Minimum server version: 5.6 +func (u *UserService) ListByUsernames(usernames []string) ([]*model.User, error) { + users, appErr := u.api.GetUsersByUsernames(usernames) + + return users, normalizeAppErr(appErr) +} + +// ListInChannel returns a page of users in a channel. Page counting starts at 0. +// The sortBy parameter can be: "username" or "status". +// +// Minimum server version: 5.6 +func (u *UserService) ListInChannel(channelID, sortBy string, page, perPage int) ([]*model.User, error) { + users, appErr := u.api.GetUsersInChannel(channelID, sortBy, page, perPage) + + return users, normalizeAppErr(appErr) +} + +// ListInTeam gets users in team. +// +// Minimum server version: 5.6 +func (u *UserService) ListInTeam(teamID string, page, perPage int) ([]*model.User, error) { + users, appErr := u.api.GetUsersInTeam(teamID, page, perPage) + + return users, normalizeAppErr(appErr) +} + +// Search returns a list of users based on some search criteria. +// +// Minimum server version: 5.6 +func (u *UserService) Search(search *model.UserSearch) ([]*model.User, error) { + users, appErr := u.api.SearchUsers(search) + + return users, normalizeAppErr(appErr) +} + +// Create creates a user. +// +// Minimum server version: 5.2 +func (u *UserService) Create(user *model.User) error { + createdUser, appErr := u.api.CreateUser(user) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *user = *createdUser + + return nil +} + +// Update updates a user. +// +// Minimum server version: 5.2 +func (u *UserService) Update(user *model.User) error { + updatedUser, appErr := u.api.UpdateUser(user) + if appErr != nil { + return normalizeAppErr(appErr) + } + + *user = *updatedUser + + return nil +} + +// Delete deletes a user. +// +// Minimum server version: 5.2 +func (u *UserService) Delete(userID string) error { + appErr := u.api.DeleteUser(userID) + + return normalizeAppErr(appErr) +} + +// GetStatus will get a user's status. +// +// Minimum server version: 5.2 +func (u *UserService) GetStatus(userID string) (*model.Status, error) { + status, appErr := u.api.GetUserStatus(userID) + + return status, normalizeAppErr(appErr) +} + +// ListStatusesByIDs will return a list of user statuses based on the provided slice of user IDs. +// +// Minimum server version: 5.2 +func (u *UserService) ListStatusesByIDs(userIDs []string) ([]*model.Status, error) { + statuses, appErr := u.api.GetUserStatusesByIds(userIDs) + + return statuses, normalizeAppErr(appErr) +} + +// UpdateStatus will set a user's status until the user, or another integration/plugin, sets it back to online. +// The status parameter can be: "online", "away", "dnd", or "offline". +// +// Minimum server version: 5.2 +func (u *UserService) UpdateStatus(userID, status string) (*model.Status, error) { + rStatus, appErr := u.api.UpdateUserStatus(userID, status) + + return rStatus, normalizeAppErr(appErr) +} + +// UpdateActive deactivates or reactivates an user. +// +// Minimum server version: 5.8 +func (u *UserService) UpdateActive(userID string, active bool) error { + appErr := u.api.UpdateUserActive(userID, active) + + return normalizeAppErr(appErr) +} + +// GetProfileImage gets user's profile image. +// +// Minimum server version: 5.6 +func (u *UserService) GetProfileImage(userID string) (io.Reader, error) { + contentBytes, appErr := u.api.GetProfileImage(userID) + if appErr != nil { + return nil, normalizeAppErr(appErr) + } + + return bytes.NewReader(contentBytes), nil +} + +// SetProfileImage sets a user's profile image. +// +// Minimum server version: 5.6 +func (u *UserService) SetProfileImage(userID string, content io.Reader) error { + contentBytes, err := io.ReadAll(content) + if err != nil { + return err + } + + return normalizeAppErr(u.api.SetProfileImage(userID, contentBytes)) +} + +// HasPermissionTo check if the user has the permission at system scope. +// +// Minimum server version: 5.3 +func (u *UserService) HasPermissionTo(userID string, permission *model.Permission) bool { + return u.api.HasPermissionTo(userID, permission) +} + +// HasPermissionToTeam check if the user has the permission at team scope. +// +// Minimum server version: 5.3 +func (u *UserService) HasPermissionToTeam(userID, teamID string, permission *model.Permission) bool { + return u.api.HasPermissionToTeam(userID, teamID, permission) +} + +// HasPermissionToChannel check if the user has the permission at channel scope. +// +// Minimum server version: 5.3 +func (u *UserService) HasPermissionToChannel(userID, channelID string, permission *model.Permission) bool { + return u.api.HasPermissionToChannel(userID, channelID, permission) +} + +// RolesGrantPermission check if the specified roles grant the specified permission +// +// Minimum server version: 6.3 +func (u *UserService) RolesGrantPermission(roleNames []string, permissionID string) bool { + return u.api.RolesGrantPermission(roleNames, permissionID) +} + +// GetLDAPAttributes will return LDAP attributes for a user. +// The attributes parameter should be a list of attributes to pull. +// Returns a map with attribute names as keys and the user's attributes as values. +// Requires an enterprise license, LDAP to be configured and for the user to use LDAP as an authentication method. +// +// Minimum server version: 5.3 +func (u *UserService) GetLDAPAttributes(userID string, attributes []string) (map[string]string, error) { + ldapUserAttributes, appErr := u.api.GetLDAPUserAttributes(userID, attributes) + + return ldapUserAttributes, normalizeAppErr(appErr) +} + +// CreateAccessToken creates a new access token. +// +// Minimum server version: 5.38 +func (u *UserService) CreateAccessToken(userID, description string) (*model.UserAccessToken, error) { + token := &model.UserAccessToken{ + UserId: userID, + Description: description, + } + + createdToken, appErr := u.api.CreateUserAccessToken(token) + + return createdToken, normalizeAppErr(appErr) +} + +// RevokeAccessToken revokes an existing access token. +// +// Minimum server version: 5.38 +func (u *UserService) RevokeAccessToken(tokenID string) error { + return normalizeAppErr(u.api.RevokeUserAccessToken(tokenID)) +} diff --git a/server/public/pluginapi/user_test.go b/server/public/pluginapi/user_test.go new file mode 100644 index 0000000000..beb5138a15 --- /dev/null +++ b/server/public/pluginapi/user_test.go @@ -0,0 +1,257 @@ +package pluginapi_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin/plugintest" + "github.com/mattermost/mattermost/server/public/pluginapi" +) + +func TestCreateUser(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + expectedUser := &model.User{ + Username: "test", + } + api.On("CreateUser", expectedUser).Return(expectedUser, nil) + + err := client.User.Create(expectedUser) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + expectedUser := &model.User{ + Username: "test", + } + api.On("CreateUser", expectedUser).Return(nil, newAppError()) + + err := client.User.Create(expectedUser) + require.EqualError(t, err, "here: id, an error occurred") + }) +} + +func TestDeleteUser(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + expectedUserID := model.NewId() + api.On("DeleteUser", expectedUserID).Return(nil) + + err := client.User.Delete(expectedUserID) + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + expectedUserID := model.NewId() + api.On("DeleteUser", expectedUserID).Return(newAppError()) + + err := client.User.Delete(expectedUserID) + require.EqualError(t, err, "here: id, an error occurred") + }) +} + +func TestGetUsers(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + options := &model.UserGetOptions{} + expectedUsers := []*model.User{{Username: "test"}} + api.On("GetUsers", options).Return(expectedUsers, nil) + + actualUsers, err := client.User.List(options) + require.NoError(t, err) + assert.Equal(t, expectedUsers, actualUsers) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + options := &model.UserGetOptions{} + api.On("GetUsers", options).Return(nil, newAppError()) + + actualUsers, err := client.User.List(options) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualUsers) + }) +} + +func TestGetUser(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + userID := "id" + expectedUser := &model.User{Id: userID, Username: "test"} + api.On("GetUser", userID).Return(expectedUser, nil) + + actualUser, err := client.User.Get(userID) + require.NoError(t, err) + assert.Equal(t, expectedUser, actualUser) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + userID := "id" + api.On("GetUser", userID).Return(nil, newAppError()) + + actualUser, err := client.User.Get(userID) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualUser) + }) +} + +func TestGetUserByEmail(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + email := "test@example.com" + expectedUser := &model.User{Email: email, Username: "test"} + api.On("GetUserByEmail", email).Return(expectedUser, nil) + + actualUser, err := client.User.GetByEmail(email) + require.NoError(t, err) + assert.Equal(t, expectedUser, actualUser) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + email := "test@example.com" + api.On("GetUserByEmail", email).Return(nil, newAppError()) + + actualUser, err := client.User.GetByEmail(email) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualUser) + }) +} + +func TestGetUserByUsername(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + username := "test" + expectedUser := &model.User{Username: username} + api.On("GetUserByUsername", username).Return(expectedUser, nil) + + actualUser, err := client.User.GetByUsername(username) + require.NoError(t, err) + assert.Equal(t, expectedUser, actualUser) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + username := "test" + api.On("GetUserByUsername", username).Return(nil, newAppError()) + + actualUser, err := client.User.GetByUsername(username) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualUser) + }) +} + +func TestGetUsersByUsernames(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + usernames := []string{"test1", "test2"} + expectedUsers := []*model.User{{Username: "test1"}, {Username: "test2"}} + api.On("GetUsersByUsernames", usernames).Return(expectedUsers, nil) + + actualUsers, err := client.User.ListByUsernames(usernames) + require.NoError(t, err) + assert.Equal(t, expectedUsers, actualUsers) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + usernames := []string{"test1", "test2"} + api.On("GetUsersByUsernames", usernames).Return(nil, newAppError()) + + actualUsers, err := client.User.ListByUsernames(usernames) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualUsers) + }) +} + +func TestGetUsersInTeam(t *testing.T) { + t.Run("success", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + teamID := "team_id" + page := 1 + perPage := 10 + expectedUsers := []*model.User{{Username: "test1"}, {Username: "test2"}} + api.On("GetUsersInTeam", teamID, page, perPage).Return(expectedUsers, nil) + + actualUsers, err := client.User.ListInTeam(teamID, page, perPage) + require.NoError(t, err) + assert.Equal(t, expectedUsers, actualUsers) + }) + + t.Run("failure", func(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + teamID := "team_id" + page := 1 + perPage := 10 + api.On("GetUsersInTeam", teamID, page, perPage).Return(nil, newAppError()) + + actualUsers, err := client.User.ListInTeam(teamID, page, perPage) + require.EqualError(t, err, "here: id, an error occurred") + assert.Nil(t, actualUsers) + }) +} + +func TestHasTeamUserPermission(t *testing.T) { + api := &plugintest.API{} + defer api.AssertExpectations(t) + client := pluginapi.NewClient(api, &plugintest.Driver{}) + + api.On("HasPermissionToTeam", "1", "2", &model.Permission{Id: "3"}).Return(true) + + ok := client.User.HasPermissionToTeam("1", "2", &model.Permission{Id: "3"}) + require.True(t, ok) +} diff --git a/server/public/pluginapi/utils.go b/server/public/pluginapi/utils.go new file mode 100644 index 0000000000..d1ff68ed34 --- /dev/null +++ b/server/public/pluginapi/utils.go @@ -0,0 +1,40 @@ +package pluginapi + +import ( + "time" +) + +func stringInSlice(a string, slice []string) bool { + for _, b := range slice { + if b == a { + return true + } + } + + return false +} + +var backoffTimeouts = []time.Duration{ + 50 * time.Millisecond, + 100 * time.Millisecond, + 200 * time.Millisecond, + 200 * time.Millisecond, + 400 * time.Millisecond, + 400 * time.Millisecond, +} + +// progressiveRetry executes a BackoffOperation and waits an increasing time before retrying the operation. +func progressiveRetry(operation func() error) error { + var err error + + for attempts := 0; attempts < len(backoffTimeouts); attempts++ { + err = operation() + if err == nil { + return nil + } + + time.Sleep(backoffTimeouts[attempts]) + } + + return err +} diff --git a/server/public/pluginapi/utils_test.go b/server/public/pluginapi/utils_test.go new file mode 100644 index 0000000000..7ee8d78656 --- /dev/null +++ b/server/public/pluginapi/utils_test.go @@ -0,0 +1,64 @@ +package pluginapi + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProgressiveRetry(t *testing.T) { + var retries int + + type args struct { + operation func() error + } + tests := []struct { + name string + args args + wantErr bool + expectedRetries int + }{ + { + name: "Should fail and return error", + args: args{ + operation: func() error { + retries++ + return errors.New("Operation Failed") + }, + }, + wantErr: true, + expectedRetries: 6, + }, + { + name: "Should succeed after two retries", + args: args{ + operation: func() error { + retries++ + if retries == 2 { + return nil + } + + return errors.New("Operation Failed") + }, + }, + wantErr: false, + expectedRetries: 2, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + retries = 0 + + err := progressiveRetry(tt.args.operation) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + assert.Equal(t, tt.expectedRetries, retries) + }) + } +} diff --git a/server/platform/shared/driver/conn.go b/server/public/shared/driver/conn.go similarity index 100% rename from server/platform/shared/driver/conn.go rename to server/public/shared/driver/conn.go diff --git a/server/platform/shared/driver/driver.go b/server/public/shared/driver/driver.go similarity index 100% rename from server/platform/shared/driver/driver.go rename to server/public/shared/driver/driver.go diff --git a/server/platform/shared/driver/objects.go b/server/public/shared/driver/objects.go similarity index 100% rename from server/platform/shared/driver/objects.go rename to server/public/shared/driver/objects.go