Merge branch 'master' into mark-as-unread

This commit is contained in:
Harrison Healey
2019-10-18 16:07:41 -04:00
20 changed files with 522 additions and 273 deletions

View File

@@ -20,9 +20,7 @@ jobs:
- checkout
- run: |
cd ../
mkdir -p ~/.ssh/
echo -e "Host github.com\n\tStrictHostKeyChecking no\n" > ~/.ssh/config
git clone git@github.com:mattermost/mattermost-webapp.git
GIT_SSH_COMMAND="ssh -o StrictHostKeyChecking=no" git clone --depth=1 git@github.com:mattermost/mattermost-webapp.git
cd mattermost-webapp
git checkout $CIRCLE_BRANCH || git checkout master
export WEBAPP_GIT_COMMIT=$(git rev-parse HEAD)

View File

@@ -78,7 +78,7 @@ TESTFLAGS ?= -short
TESTFLAGSEE ?= -short
# Packages lists
TE_PACKAGES=$(shell go list ./...)
TE_PACKAGES=$(shell $(GO) list ./...)
# Plugins Packages
PLUGIN_PACKAGES=mattermost-plugin-zoom-v1.1.1
@@ -104,7 +104,7 @@ else
IGNORE:=$(shell rm -f imports/imports.go)
endif
EE_PACKAGES=$(shell go list ./enterprise/...)
EE_PACKAGES=$(shell $(GO) list ./enterprise/...)
ifeq ($(BUILD_ENTERPRISE_READY),true)
ALL_PACKAGES=$(TE_PACKAGES) $(EE_PACKAGES)
@@ -154,7 +154,7 @@ govet: ## Runs govet against all packages.
env GO111MODULE=off $(GO) get golang.org/x/tools/go/analysis/passes/shadow/cmd/shadow
$(GO) vet $(GOFLAGS) $(ALL_PACKAGES) || exit 1
$(GO) vet -vettool=$(GOPATH)/bin/shadow $(GOFLAGS) $(ALL_PACKAGES) || exit 1
$(GO) run plugin/checker/main.go
$(GO) run $(GOFLAGS) plugin/checker/main.go
gofmt: ## Runs gofmt against all packages.
@echo Running GOFMT

View File

@@ -11,6 +11,7 @@ import (
"github.com/mattermost/mattermost-server/mlog"
"github.com/mattermost/mattermost-server/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetPing(t *testing.T) {
@@ -66,24 +67,15 @@ func TestGetAudits(t *testing.T) {
audits, resp := th.SystemAdminClient.GetAudits(0, 100, "")
CheckNoError(t, resp)
if len(audits) == 0 {
t.Fatal("should not be empty")
}
require.NotEmpty(t, audits, "should not be empty")
audits, resp = th.SystemAdminClient.GetAudits(0, 1, "")
CheckNoError(t, resp)
if len(audits) != 1 {
t.Fatal("should only be 1")
}
require.Len(t, audits, 1, "should only be 1")
audits, resp = th.SystemAdminClient.GetAudits(1, 1, "")
CheckNoError(t, resp)
if len(audits) != 1 {
t.Fatal("should only be 1")
}
require.Len(t, audits, 1, "should only be 1")
_, resp = th.SystemAdminClient.GetAudits(-1, -1, "")
CheckNoError(t, resp)
@@ -222,17 +214,13 @@ func TestInvalidateCaches(t *testing.T) {
t.Run("as system user", func(t *testing.T) {
ok, resp := Client.InvalidateCaches()
CheckForbiddenStatus(t, resp)
if ok {
t.Fatal("should not clean the cache due no permission.")
}
require.False(t, ok, "should not clean the cache due to no permission.")
})
t.Run("as system admin", func(t *testing.T) {
ok, resp := th.SystemAdminClient.InvalidateCaches()
CheckNoError(t, resp)
if !ok {
t.Fatal("should clean the cache")
}
require.True(t, ok, "should clean the cache")
})
t.Run("as restricted system admin", func(t *testing.T) {
@@ -240,9 +228,7 @@ func TestInvalidateCaches(t *testing.T) {
ok, resp := th.SystemAdminClient.InvalidateCaches()
CheckForbiddenStatus(t, resp)
if ok {
t.Fatal("should not clean the cache due no permission.")
}
require.False(t, ok, "should not clean the cache due to no permission.")
})
}
@@ -257,29 +243,19 @@ func TestGetLogs(t *testing.T) {
logs, resp := th.SystemAdminClient.GetLogs(0, 10)
CheckNoError(t, resp)
require.Len(t, logs, 10)
if len(logs) != 10 {
t.Log(len(logs))
t.Fatal("wrong length")
}
for i := 10; i < 20; i++ {
assert.Containsf(t, logs[i-10], fmt.Sprintf(`"msg":"%d"`, i), "Log line doesn't contain correct message")
}
logs, resp = th.SystemAdminClient.GetLogs(1, 10)
CheckNoError(t, resp)
if len(logs) != 10 {
t.Log(len(logs))
t.Fatal("wrong length")
}
require.Len(t, logs, 10)
logs, resp = th.SystemAdminClient.GetLogs(-1, -1)
CheckNoError(t, resp)
if len(logs) == 0 {
t.Fatal("should not be empty")
}
require.NotEmpty(t, logs, "should not be empty")
_, resp = Client.GetLogs(0, 10)
CheckForbiddenStatus(t, resp)
@@ -319,9 +295,8 @@ func TestPostLog(t *testing.T) {
logMessage, resp := th.SystemAdminClient.PostLog(message)
CheckNoError(t, resp)
if len(logMessage) == 0 {
t.Fatal("should return the log message")
}
require.NotEmpty(t, logMessage, "should return the log message")
}
func TestGetAnalyticsOld(t *testing.T) {
@@ -331,10 +306,7 @@ func TestGetAnalyticsOld(t *testing.T) {
rows, resp := Client.GetAnalyticsOld("", "")
CheckForbiddenStatus(t, resp)
if rows != nil {
t.Fatal("should be nil")
}
require.Nil(t, rows, "should be nil")
rows, resp = th.SystemAdminClient.GetAnalyticsOld("", "")
CheckNoError(t, resp)
@@ -376,10 +348,7 @@ func TestGetAnalyticsOld(t *testing.T) {
assert.Equal(t, float64(0), rows2[5].Value)
WebSocketClient, err := th.CreateWebSocketClient()
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
rows2, resp2 = th.SystemAdminClient.GetAnalyticsOld("standard", "")
CheckNoError(t, resp2)
assert.Equal(t, "total_websocket_connections", rows2[5].Name)
@@ -433,10 +402,7 @@ func TestS3TestConnection(t *testing.T) {
t.Run("as system admin", func(t *testing.T) {
_, resp := th.SystemAdminClient.TestS3Connection(&config)
CheckBadRequestStatus(t, resp)
if resp.Error.Message != "S3 Bucket is required" {
t.Fatal("should return error - missing s3 bucket")
}
require.Equal(t, resp.Error.Message, "S3 Bucket is required", "should return error - missing s3 bucket")
// If this fails, check the test configuration to ensure minio is setup with the
// `mattermost-test` bucket defined by model.MINIO_BUCKET.
*config.FileSettings.AmazonS3Bucket = model.MINIO_BUCKET

View File

@@ -5,6 +5,7 @@ import (
"github.com/mattermost/mattermost-server/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTermsOfService(t *testing.T) {
@@ -13,9 +14,7 @@ func TestGetTermsOfService(t *testing.T) {
Client := th.Client
_, err := th.App.CreateTermsOfService("abc", th.BasicUser.Id)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
termsOfService, resp := Client.GetTermsOfService("")
CheckNoError(t, resp)

View File

@@ -223,6 +223,8 @@ func (a *App) RenameChannel(channel *model.Channel, newChannelName string, newDi
}
func (a *App) CreateChannel(channel *model.Channel, addMember bool) (*model.Channel, *model.AppError) {
channel.DisplayName = strings.TrimSpace(channel.DisplayName)
sc, err := a.Srv.Store.Channel().Save(channel, *a.Config().TeamSettings.MaxChannelsPerTeam)
if err != nil {
return nil, err

View File

@@ -254,6 +254,14 @@ func TestCreateChannelPrivateCreatesChannelMemberHistoryRecord(t *testing.T) {
assert.Equal(t, th.BasicUser.Id, histories[0].UserId)
assert.Equal(t, privateChannel.Id, histories[0].ChannelId)
}
func TestCreateChannelDisplayNameTrimsWhitespace(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
channel, err := th.App.CreateChannel(&model.Channel{DisplayName: " Public 1 ", Name: "public1", Type: model.CHANNEL_OPEN, TeamId: th.BasicTeam.Id}, false)
require.Nil(t, err)
require.Equal(t, channel.DisplayName, "Public 1")
}
func TestUpdateChannelPrivacy(t *testing.T) {
th := Setup(t).InitBasic()

View File

@@ -18,6 +18,7 @@ import (
"net/http"
"net/url"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
@@ -135,33 +136,8 @@ func (a *App) ListDirectory(path string) ([]string, *model.AppError) {
return *paths, nil
}
func (a *App) GetInfoForFilename(post *model.Post, teamId string, filename string) *model.FileInfo {
// Find the path from the Filename of the form /{channelId}/{userId}/{uid}/{nameWithExtension}
split := strings.SplitN(filename, "/", 5)
if len(split) < 5 {
mlog.Error(
"Unable to decipher filename when migrating post to use FileInfos",
mlog.String("post_id", post.Id),
mlog.String("filename", filename),
)
return nil
}
channelId := split[1]
userId := split[2]
oldId := split[3]
name, _ := url.QueryUnescape(split[4])
if split[0] != "" || split[1] != post.ChannelId || split[2] != post.UserId || strings.Contains(split[4], "/") {
mlog.Warn(
"Found an unusual filename when migrating post to use FileInfos",
mlog.String("post_id", post.Id),
mlog.String("channel_id", post.ChannelId),
mlog.String("user_id", post.UserId),
mlog.String("filename", filename),
)
}
func (a *App) getInfoForFilename(post *model.Post, teamId, channelId, userId, oldId, filename string) *model.FileInfo {
name, _ := url.QueryUnescape(filename)
pathPrefix := fmt.Sprintf("teams/%s/channels/%s/users/%s/%s/", teamId, channelId, userId, oldId)
path := pathPrefix + name
@@ -204,10 +180,8 @@ func (a *App) GetInfoForFilename(post *model.Post, teamId string, filename strin
return info
}
func (a *App) FindTeamIdForFilename(post *model.Post, filename string) string {
split := strings.SplitN(filename, "/", 5)
id := split[3]
name, _ := url.QueryUnescape(split[4])
func (a *App) findTeamIdForFilename(post *model.Post, id, filename string) string {
name, _ := url.QueryUnescape(filename)
// This post is in a direct channel so we need to figure out what team the files are stored under.
teams, err := a.Srv.Store.Team().GetTeamsByUserId(post.UserId)
@@ -223,7 +197,7 @@ func (a *App) FindTeamIdForFilename(post *model.Post, filename string) string {
for _, team := range teams {
path := fmt.Sprintf("teams/%s/channels/%s/users/%s/%s/%s", team.Id, post.ChannelId, post.UserId, id, name)
if _, err := a.ReadFile(path); err == nil {
if ok, err := a.FileExists(path); ok && err == nil {
// Found the team that this file was posted from
return team.Id
}
@@ -233,6 +207,27 @@ func (a *App) FindTeamIdForFilename(post *model.Post, filename string) string {
}
var fileMigrationLock sync.Mutex
var oldFilenameMatchExp *regexp.Regexp = regexp.MustCompile(`^\/([a-z\d]{26})\/([a-z\d]{26})\/([a-z\d]{26})\/([^\/]+)$`)
// Parse the path from the Filename of the form /{channelId}/{userId}/{uid}/{nameWithExtension}
func parseOldFilenames(filenames []string, channelId, userId string) [][]string {
parsed := [][]string{}
for _, filename := range filenames {
matches := oldFilenameMatchExp.FindStringSubmatch(filename)
if len(matches) != 5 {
mlog.Error("Failed to parse old Filename", mlog.String("filename", filename))
continue
}
if matches[1] != channelId {
mlog.Error("ChannelId in Filename does not match", mlog.String("channel_id", channelId), mlog.String("matched", matches[1]))
} else if matches[2] != userId {
mlog.Error("UserId in Filename does not match", mlog.String("user_id", userId), mlog.String("matched", matches[2]))
} else {
parsed = append(parsed, matches[1:])
}
}
return parsed
}
// Creates and stores FileInfos for a post created before the FileInfos table existed.
func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo {
@@ -254,11 +249,19 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo {
return []*model.FileInfo{}
}
// Parse and validate filenames before further processing
parsedFilenames := parseOldFilenames(filenames, post.ChannelId, post.UserId)
if len(parsedFilenames) == 0 {
mlog.Error("Unable to parse filenames")
return []*model.FileInfo{}
}
// Find the team that was used to make this post since its part of the file path that isn't saved in the Filename
var teamId string
if channel.TeamId == "" {
// This post was made in a cross-team DM channel, so we need to find where its files were saved
teamId = a.FindTeamIdForFilename(post, filenames[0])
teamId = a.findTeamIdForFilename(post, parsedFilenames[0][2], parsedFilenames[0][3])
} else {
teamId = channel.TeamId
}
@@ -272,8 +275,8 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo {
mlog.String("post_id", post.Id),
)
} else {
for _, filename := range filenames {
info := a.GetInfoForFilename(post, teamId, filename)
for _, parsed := range parsedFilenames {
info := a.getInfoForFilename(post, teamId, parsed[0], parsed[1], parsed[2], parsed[3])
if info == nil {
continue
}

View File

@@ -53,7 +53,7 @@ func TestDoUploadFile(t *testing.T) {
}()
value := fmt.Sprintf("20070204/teams/%v/channels/%v/users/%v/%v/%v", teamId, channelId, userId, info1.Id, filename)
assert.Equal(t, value, info1.Path, "stored file at incorrect path" )
assert.Equal(t, value, info1.Path, "stored file at incorrect path")
info2, err := th.App.DoUploadFile(time.Date(2007, 2, 4, 1, 2, 3, 4, time.Local), teamId, channelId, userId, filename, data)
require.Nil(t, err, "DoUploadFile should succeed with valid data")
@@ -106,6 +106,103 @@ func TestUploadFile(t *testing.T) {
assert.Equal(t, value, info1.Path, "Stored file at incorrect path")
}
func TestParseOldFilenames(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
fileId := model.NewId()
tests := []struct {
description string
filenames []string
channelId string
userId string
expected [][]string
}{
{
description: "Empty input should result in empty output",
filenames: []string{},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{},
},
{
description: "Filename with invalid format should not parse",
filenames: []string{"/path/to/some/file.png"},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{},
},
{
description: "ChannelId in Filename should not match",
filenames: []string{
fmt.Sprintf("/%v/%v/%v/file.png", model.NewId(), th.BasicUser.Id, fileId),
},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{},
},
{
description: "UserId in Filename should not match",
filenames: []string{
fmt.Sprintf("/%v/%v/%v/file.png", th.BasicChannel.Id, model.NewId(), fileId),
},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{},
},
{
description: "../ in filename should not parse",
filenames: []string{
fmt.Sprintf("/%v/%v/%v/../../../file.png", th.BasicChannel.Id, th.BasicUser.Id, fileId),
},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{},
},
{
description: "Should only parse valid filenames",
filenames: []string{
fmt.Sprintf("/%v/%v/%v/../otherfile.png", th.BasicChannel.Id, th.BasicUser.Id, fileId),
fmt.Sprintf("/%v/%v/%v/file.png", th.BasicChannel.Id, th.BasicUser.Id, fileId),
},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{
{
th.BasicChannel.Id,
th.BasicUser.Id,
fileId,
"file.png",
},
},
},
{
description: "Valid Filename should parse",
filenames: []string{
fmt.Sprintf("/%v/%v/%v/file.png", th.BasicChannel.Id, th.BasicUser.Id, fileId),
},
channelId: th.BasicChannel.Id,
userId: th.BasicUser.Id,
expected: [][]string{
{
th.BasicChannel.Id,
th.BasicUser.Id,
fileId,
"file.png",
},
},
},
}
for _, test := range tests {
t.Run(test.description, func(tt *testing.T) {
result := parseOldFilenames(test.filenames, test.channelId, test.userId)
require.Equal(tt, result, test.expected)
})
}
}
func TestGetInfoForFilename(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
@@ -113,10 +210,7 @@ func TestGetInfoForFilename(t *testing.T) {
post := th.BasicPost
teamId := th.BasicTeam.Id
info := th.App.GetInfoForFilename(post, teamId, "sometestfile")
assert.Nil(t, info, "Test bad filename")
info = th.App.GetInfoForFilename(post, teamId, "/somechannel/someuser/someid/somefile.png")
info := th.App.getInfoForFilename(post, teamId, post.ChannelId, post.UserId, "someid", "somefile.png")
assert.Nil(t, info, "Test non-existent file")
}
@@ -124,13 +218,13 @@ func TestFindTeamIdForFilename(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
teamId := th.App.FindTeamIdForFilename(th.BasicPost, fmt.Sprintf("/%v/%v/%v/blargh.png", th.BasicChannel.Id, th.BasicUser.Id, "someid"))
teamId := th.App.findTeamIdForFilename(th.BasicPost, "someid", "somefile.png")
assert.Equal(t, th.BasicTeam.Id, teamId)
_, err := th.App.CreateTeamWithUser(&model.Team{Email: th.BasicUser.Email, Name: "zz" + model.NewId(), DisplayName: "Joram's Test Team", Type: model.TEAM_OPEN}, th.BasicUser.Id)
require.Nil(t, err)
teamId = th.App.FindTeamIdForFilename(th.BasicPost, fmt.Sprintf("/%v/%v/%v/blargh.png", th.BasicChannel.Id, th.BasicUser.Id, "someid"))
teamId = th.App.findTeamIdForFilename(th.BasicPost, "someid", "somefile.png")
assert.Equal(t, "", teamId)
}
@@ -151,14 +245,21 @@ func TestMigrateFilenamesToFileInfos(t *testing.T) {
require.Nil(t, fileErr)
defer file.Close()
fpath := fmt.Sprintf("/teams/%v/channels/%v/users/%v/%v/test.png", th.BasicTeam.Id, th.BasicChannel.Id, th.BasicUser.Id, "someid")
fileId := model.NewId()
fpath := fmt.Sprintf("/teams/%v/channels/%v/users/%v/%v/test.png", th.BasicTeam.Id, th.BasicChannel.Id, th.BasicUser.Id, fileId)
_, err := th.App.WriteFile(file, fpath)
require.Nil(t, err)
rpost, err := th.App.CreatePost(&model.Post{UserId: th.BasicUser.Id, ChannelId: th.BasicChannel.Id, Filenames: []string{fmt.Sprintf("/%v/%v/%v/test.png", th.BasicChannel.Id, th.BasicUser.Id, "someid")}}, th.BasicChannel, false)
rpost, err := th.App.CreatePost(&model.Post{UserId: th.BasicUser.Id, ChannelId: th.BasicChannel.Id, Filenames: []string{fmt.Sprintf("/%v/%v/%v/test.png", th.BasicChannel.Id, th.BasicUser.Id, fileId)}}, th.BasicChannel, false)
require.Nil(t, err)
infos = th.App.MigrateFilenamesToFileInfos(rpost)
assert.Equal(t, 1, len(infos))
rpost, err = th.App.CreatePost(&model.Post{UserId: th.BasicUser.Id, ChannelId: th.BasicChannel.Id, Filenames: []string{fmt.Sprintf("/%v/%v/%v/../../test.png", th.BasicChannel.Id, th.BasicUser.Id, fileId)}}, th.BasicChannel, false)
require.Nil(t, err)
infos = th.App.MigrateFilenamesToFileInfos(rpost)
assert.Equal(t, 0, len(infos))
}
func TestCopyFileInfos(t *testing.T) {

View File

@@ -7,6 +7,8 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"path/filepath"
"strings"
@@ -404,6 +406,17 @@ func (api *PluginAPI) AddChannelMember(channelId, userId string) (*model.Channel
return api.app.AddChannelMember(userId, channel, userRequestorId, postRootId)
}
func (api *PluginAPI) AddUserToChannel(channelId, userId, asUserId string) (*model.ChannelMember, *model.AppError) {
postRootId := ""
channel, err := api.GetChannel(channelId)
if err != nil {
return nil, err
}
return api.app.AddChannelMember(userId, channel, asUserId, postRootId)
}
func (api *PluginAPI) GetChannelMember(channelId, userId string) (*model.ChannelMember, *model.AppError) {
return api.app.GetChannelMember(channelId, userId)
}
@@ -655,6 +668,19 @@ func (api *PluginAPI) GetPluginStatus(id string) (*model.PluginStatus, *model.Ap
return api.app.GetPluginStatus(id)
}
func (api *PluginAPI) InstallPlugin(file io.Reader, replace bool) (*model.Manifest, *model.AppError) {
if !*api.app.Config().PluginSettings.Enable || !*api.app.Config().PluginSettings.EnableUploads {
return nil, model.NewAppError("installPlugin", "app.plugin.upload_disabled.app_error", nil, "", http.StatusNotImplemented)
}
fileBuffer, err := ioutil.ReadAll(file)
if err != nil {
return nil, model.NewAppError("InstallPlugin", "api.plugin.upload.file.app_error", nil, "", http.StatusBadRequest)
}
return api.app.InstallPlugin(bytes.NewReader(fileBuffer), replace)
}
// KV Store Section
func (api *PluginAPI) KVSet(key string, value []byte) *model.AppError {

View File

@@ -21,6 +21,7 @@ import (
"github.com/mattermost/mattermost-server/plugin"
"github.com/mattermost/mattermost-server/services/mailservice"
"github.com/mattermost/mattermost-server/utils"
"github.com/mattermost/mattermost-server/utils/fileutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -684,6 +685,43 @@ func TestPluginAPIGetPlugins(t *testing.T) {
assert.Equal(t, pluginManifests, plugins)
}
func TestPluginAPIInstallPlugin(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
api := th.SetupPluginAPI()
path, _ := fileutils.FindDir("tests")
tarData, err := ioutil.ReadFile(filepath.Join(path, "testplugin.tar.gz"))
require.NoError(t, err)
_, err = api.InstallPlugin(bytes.NewReader(tarData), true)
assert.NotNil(t, err, "should not allow upload if upload disabled")
assert.Equal(t, err.Error(), "installPlugin: Plugins and/or plugin uploads have been disabled., ")
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.PluginSettings.Enable = true
*cfg.PluginSettings.EnableUploads = true
})
manifest, err := api.InstallPlugin(bytes.NewReader(tarData), true)
defer os.RemoveAll("plugins/testplugin")
require.Nil(t, err)
assert.Equal(t, "testplugin", manifest.Id)
// Successfully installed
pluginsResp, err := api.GetPlugins()
require.Nil(t, err)
found := false
for _, m := range pluginsResp {
if m.Id == manifest.Id {
found = true
}
}
assert.True(t, found)
}
func TestPluginAPIGetTeamIcon(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
@@ -1401,3 +1439,15 @@ func TestPluginAPIGetUnsanitizedConfig(t *testing.T) {
assert.NotEqual(t, config.SqlSettings.DataSourceSearchReplicas[i], model.FAKE_SETTING)
}
}
func TestPluginAddUserToChannel(t *testing.T) {
th := Setup(t).InitBasic()
defer th.TearDown()
api := th.SetupPluginAPI()
member, err := api.AddUserToChannel(th.BasicChannel.Id, th.BasicUser.Id, th.BasicUser2.Id)
require.Nil(t, err)
require.NotNil(t, member)
require.Equal(t, th.BasicChannel.Id, member.ChannelId)
require.Equal(t, th.BasicUser.Id, member.UserId)
}

View File

@@ -215,8 +215,9 @@ func NewServer(options ...Option) (*Server, error) {
mlog.Info(fmt.Sprintf("Current version is %v (%v/%v/%v/%v)", model.CurrentVersion, model.BuildNumber, model.BuildDate, model.BuildHash, model.BuildHashEnterprise))
mlog.Info(fmt.Sprintf("Enterprise Enabled: %v", model.BuildEnterpriseReady))
pwd, _ := os.Getwd()
mlog.Info(fmt.Sprintf("Current working directory is %v", pwd))
mlog.Info("Printing current working", mlog.String("directory", pwd))
mlog.Info("Loaded config", mlog.String("source", s.configStore.String()))
s.checkPushNotificationServerUrl()
@@ -244,7 +245,7 @@ func NewServer(options ...Option) (*Server, error) {
}
if err := s.Store.Status().ResetAll(); err != nil {
mlog.Error(fmt.Sprint("Error to reset the server status.", err.Error()))
mlog.Error("Error to reset the server status.", mlog.Err(err))
}
if s.joinCluster && s.Cluster != nil {
@@ -310,7 +311,7 @@ func (s *Server) StopHTTPServer() {
didShutdown := false
for s.didFinishListen != nil && !didShutdown {
if err := s.Server.Shutdown(ctx); err != nil {
mlog.Warn(err.Error())
mlog.Warn("Unable to shutdown server", mlog.Err(err))
}
timer := time.NewTimer(time.Millisecond * 50)
select {
@@ -332,7 +333,7 @@ func (s *Server) Shutdown() error {
err := s.shutdownDiagnostics()
if err != nil {
mlog.Error(fmt.Sprintf("Unable to cleanly shutdown diagnostic client: %s", err))
mlog.Error("Unable to cleanly shutdown diagnostic client", mlog.Err(err))
}
s.StopHTTPServer()
@@ -502,7 +503,7 @@ func (s *Server) Start() error {
if *s.Config().ServiceSettings.Forward80To443 {
if host, port, err := net.SplitHostPort(addr); err != nil {
mlog.Error("Unable to setup forwarding: " + err.Error())
mlog.Error("Unable to setup forwarding", mlog.Err(err))
} else if port != "443" {
return fmt.Errorf(utils.T("api.server.start_server.forward80to443.enabled_but_listening_on_wrong_port"), port)
} else {
@@ -519,7 +520,7 @@ func (s *Server) Start() error {
go func() {
redirectListener, err := net.Listen("tcp", httpListenAddress)
if err != nil {
mlog.Error("Unable to setup forwarding: " + err.Error())
mlog.Error("Unable to setup forwarding", mlog.Err(err))
return
}
defer redirectListener.Close()
@@ -605,7 +606,7 @@ func (s *Server) Start() error {
}
if err != nil && err != http.ErrServerClosed {
mlog.Critical(fmt.Sprintf("Error starting server, err:%v", err))
mlog.Critical("Error starting server", mlog.Err(err))
time.Sleep(time.Second)
}

View File

@@ -190,6 +190,7 @@ const (
PLUGIN_SETTINGS_DEFAULT_CLIENT_DIRECTORY = "./client/plugins"
PLUGIN_SETTINGS_DEFAULT_ENABLE_MARKETPLACE = true
PLUGIN_SETTINGS_DEFAULT_MARKETPLACE_URL = "https://api.integrations.mattermost.com"
PLUGIN_SETTINGS_OLD_MARKETPLACE_URL = "https://marketplace.integrations.mattermost.com"
COMPLIANCE_EXPORT_TYPE_CSV = "csv"
COMPLIANCE_EXPORT_TYPE_ACTIANCE = "actiance"
@@ -2284,7 +2285,7 @@ func (s *PluginSettings) SetDefaults(ls LogSettings) {
s.EnableMarketplace = NewBool(PLUGIN_SETTINGS_DEFAULT_ENABLE_MARKETPLACE)
}
if s.MarketplaceUrl == nil || *s.MarketplaceUrl == "" {
if s.MarketplaceUrl == nil || *s.MarketplaceUrl == "" || *s.MarketplaceUrl == PLUGIN_SETTINGS_OLD_MARKETPLACE_URL {
s.MarketplaceUrl = NewString(PLUGIN_SETTINGS_DEFAULT_MARKETPLACE_URL)
}
}

View File

@@ -1155,3 +1155,37 @@ func TestConfigSanitize(t *testing.T) {
assert.Equal(t, FAKE_SETTING, c.SqlSettings.DataSourceReplicas[0])
assert.Equal(t, FAKE_SETTING, c.SqlSettings.DataSourceSearchReplicas[0])
}
func TestConfigMarketplaceDefaults(t *testing.T) {
t.Parallel()
t.Run("no marketplace url", func(t *testing.T) {
c := Config{}
c.SetDefaults()
require.True(t, *c.PluginSettings.EnableMarketplace)
require.Equal(t, PLUGIN_SETTINGS_DEFAULT_MARKETPLACE_URL, *c.PluginSettings.MarketplaceUrl)
})
t.Run("old marketplace url", func(t *testing.T) {
c := Config{}
c.SetDefaults()
*c.PluginSettings.MarketplaceUrl = PLUGIN_SETTINGS_OLD_MARKETPLACE_URL
c.SetDefaults()
require.True(t, *c.PluginSettings.EnableMarketplace)
require.Equal(t, PLUGIN_SETTINGS_DEFAULT_MARKETPLACE_URL, *c.PluginSettings.MarketplaceUrl)
})
t.Run("custom marketplace url", func(t *testing.T) {
c := Config{}
c.SetDefaults()
*c.PluginSettings.MarketplaceUrl = "https://marketplace.example.com"
c.SetDefaults()
require.True(t, *c.PluginSettings.EnableMarketplace)
require.Equal(t, "https://marketplace.example.com", *c.PluginSettings.MarketplaceUrl)
})
}

View File

@@ -4,6 +4,8 @@
package plugin
import (
"io"
plugin "github.com/hashicorp/go-plugin"
"github.com/mattermost/mattermost-server/model"
)
@@ -334,11 +336,18 @@ type API interface {
// Minimum server version: 5.10
SearchPostsInTeam(teamId string, paramsList []*model.SearchParams) ([]*model.Post, *model.AppError)
// AddChannelMember creates a channel membership for a user.
// AddChannelMember joins a user to a channel (as if they joined themselves)
// This means the user will not receive notifications for joining the channel.
//
// Minimum server version: 5.2
AddChannelMember(channelId, userId string) (*model.ChannelMember, *model.AppError)
// AddUserToChannel adds a user to a channel as if the specified user had invited them.
// This means the user will receive the regular notifications for being added to the channel.
//
// Minimum server version: 5.18
AddUserToChannel(channelId, userId, asUserId string) (*model.ChannelMember, *model.AppError)
// GetChannelMember gets a channel membership for a user.
//
// Minimum server version: 5.2
@@ -557,6 +566,12 @@ type API interface {
// Minimum server version: 5.6
GetPluginStatus(id string) (*model.PluginStatus, *model.AppError)
// InstallPlugin will upload another plugin with tar.gz file.
// Previous version will be replaced on replace true.
//
// Minimum server version: 5.18
InstallPlugin(file io.Reader, replace bool) (*model.Manifest, *model.AppError)
// KV Store Section
// KVSet stores a key-value pair, unique per plugin.

View File

@@ -4,6 +4,7 @@
package main
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
@@ -35,6 +36,11 @@ func TestRunCheck(t *testing.T) {
},
}
// Enable debug flag to have packagesdriver/sizes.go print stderr of `go list` command.
// We want to surface any error text that may exist in stderr of this command.
prevEnvValue := os.Getenv("GOPACKAGESPRINTGOLISTERRORS")
os.Setenv("GOPACKAGESPRINTGOLISTERRORS", "true")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := runCheck(tc.pkgPath)
@@ -46,4 +52,5 @@ func TestRunCheck(t *testing.T) {
}
})
}
os.Setenv("GOPACKAGESPRINTGOLISTERRORS", prevEnvValue)
}

View File

@@ -8,6 +8,7 @@ package plugin
import (
"fmt"
"io"
"log"
"github.com/mattermost/mattermost-server/mlog"
@@ -2281,6 +2282,37 @@ func (s *apiRPCServer) AddChannelMember(args *Z_AddChannelMemberArgs, returns *Z
return nil
}
type Z_AddUserToChannelArgs struct {
A string
B string
C string
}
type Z_AddUserToChannelReturns struct {
A *model.ChannelMember
B *model.AppError
}
func (g *apiRPCClient) AddUserToChannel(channelId, userId, asUserId string) (*model.ChannelMember, *model.AppError) {
_args := &Z_AddUserToChannelArgs{channelId, userId, asUserId}
_returns := &Z_AddUserToChannelReturns{}
if err := g.client.Call("Plugin.AddUserToChannel", _args, _returns); err != nil {
log.Printf("RPC call to AddUserToChannel API failed: %s", err.Error())
}
return _returns.A, _returns.B
}
func (s *apiRPCServer) AddUserToChannel(args *Z_AddUserToChannelArgs, returns *Z_AddUserToChannelReturns) error {
if hook, ok := s.impl.(interface {
AddUserToChannel(channelId, userId, asUserId string) (*model.ChannelMember, *model.AppError)
}); ok {
returns.A, returns.B = hook.AddUserToChannel(args.A, args.B, args.C)
} else {
return encodableError(fmt.Errorf("API AddUserToChannel called but not implemented."))
}
return nil
}
type Z_GetChannelMemberArgs struct {
A string
B string
@@ -3488,6 +3520,36 @@ func (s *apiRPCServer) GetPluginStatus(args *Z_GetPluginStatusArgs, returns *Z_G
return nil
}
type Z_InstallPluginArgs struct {
A io.Reader
B bool
}
type Z_InstallPluginReturns struct {
A *model.Manifest
B *model.AppError
}
func (g *apiRPCClient) InstallPlugin(file io.Reader, replace bool) (*model.Manifest, *model.AppError) {
_args := &Z_InstallPluginArgs{file, replace}
_returns := &Z_InstallPluginReturns{}
if err := g.client.Call("Plugin.InstallPlugin", _args, _returns); err != nil {
log.Printf("RPC call to InstallPlugin API failed: %s", err.Error())
}
return _returns.A, _returns.B
}
func (s *apiRPCServer) InstallPlugin(args *Z_InstallPluginArgs, returns *Z_InstallPluginReturns) error {
if hook, ok := s.impl.(interface {
InstallPlugin(file io.Reader, replace bool) (*model.Manifest, *model.AppError)
}); ok {
returns.A, returns.B = hook.InstallPlugin(args.A, args.B)
} else {
return encodableError(fmt.Errorf("API InstallPlugin called but not implemented."))
}
return nil
}
type Z_KVSetArgs struct {
A string
B []byte

View File

@@ -5,6 +5,8 @@
package plugintest
import (
io "io"
model "github.com/mattermost/mattermost-server/model"
mock "github.com/stretchr/testify/mock"
)
@@ -64,6 +66,31 @@ func (_m *API) AddReaction(reaction *model.Reaction) (*model.Reaction, *model.Ap
return r0, r1
}
// AddUserToChannel provides a mock function with given fields: channelId, userId, asUserId
func (_m *API) AddUserToChannel(channelId string, userId string, asUserId string) (*model.ChannelMember, *model.AppError) {
ret := _m.Called(channelId, userId, asUserId)
var r0 *model.ChannelMember
if rf, ok := ret.Get(0).(func(string, string, string) *model.ChannelMember); ok {
r0 = rf(channelId, userId, asUserId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.ChannelMember)
}
}
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(string, string, string) *model.AppError); ok {
r1 = rf(channelId, userId, asUserId)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// CopyFileInfos provides a mock function with given fields: userId, fileIds
func (_m *API) CopyFileInfos(userId string, fileIds []string) ([]string, *model.AppError) {
ret := _m.Called(userId, fileIds)
@@ -2783,3 +2810,28 @@ func (_m *API) UploadFile(data []byte, channelId string, filename string) (*mode
return r0, r1
}
// InstallPlugin provides a mock function with given fields: file, replace
func (_m *API) InstallPlugin(file io.Reader, replace bool) (*model.Manifest, *model.AppError) {
ret := _m.Called(file, replace)
var r0 *model.Manifest
if rf, ok := ret.Get(0).(func(io.Reader, bool) *model.Manifest); ok {
r0 = rf(file, replace)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Manifest)
}
}
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(io.Reader, bool) *model.AppError); ok {
r1 = rf(file, replace)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}

View File

@@ -12,11 +12,9 @@ func TestMapStringsToQueryParams(t *testing.T) {
keys, params := MapStringsToQueryParams(input, "Fruit")
if len(params) != 1 || params["Fruit0"] != "apple" {
t.Fatal("returned incorrect params", params)
} else if keys != "(:Fruit0)" {
t.Fatal("returned incorrect query", keys)
}
require.Len(t, params, 1, "returned incorrect params", params)
require.Equal(t, "apple", params["Fruit0"], "returned incorrect params", params)
require.Equal(t, "(:Fruit0)", keys, "returned incorrect query", keys)
})
t.Run("multiple items", func(t *testing.T) {
@@ -24,12 +22,11 @@ func TestMapStringsToQueryParams(t *testing.T) {
keys, params := MapStringsToQueryParams(input, "Vegetable")
if len(params) != 3 || params["Vegetable0"] != "carrot" ||
params["Vegetable1"] != "tomato" || params["Vegetable2"] != "potato" {
t.Fatal("returned incorrect params", params)
} else if keys != "(:Vegetable0,:Vegetable1,:Vegetable2)" {
t.Fatal("returned incorrect query", keys)
}
require.Len(t, params, 3, "returned incorrect params", params)
require.Equal(t, "carrot", params["Vegetable0"], "returned incorrect params", params)
require.Equal(t, "tomato", params["Vegetable1"], "returned incorrect params", params)
require.Equal(t, "potato", params["Vegetable2"], "returned incorrect params", params)
require.Equal(t, "(:Vegetable0,:Vegetable1,:Vegetable2)", keys, "returned incorrect query", keys)
})
}

View File

@@ -6,6 +6,8 @@ package storetest
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost-server/model"
"github.com/mattermost/mattermost-server/store"
)
@@ -30,13 +32,11 @@ func testCommandStoreSave(t *testing.T, ss store.Store) {
o1.URL = "http://nowhere.com/"
o1.Trigger = "trigger"
if _, err := ss.Command().Save(&o1); err != nil {
t.Fatal("couldn't save item", err)
}
_, err := ss.Command().Save(&o1)
require.Nil(t, err, "couldn't save item")
if _, err := ss.Command().Save(&o1); err == nil {
t.Fatal("shouldn't be able to update from save")
}
_, err = ss.Command().Save(&o1)
require.NotNil(t, err, "shouldn't be able to update from save")
}
func testCommandStoreGet(t *testing.T, ss store.Store) {
@@ -48,21 +48,14 @@ func testCommandStoreGet(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if r1, err := ss.Command().Get(o1.Id); err != nil {
t.Fatal(err)
} else {
if r1.CreateAt != o1.CreateAt {
t.Fatal("invalid returned command")
}
}
r1, err := ss.Command().Get(o1.Id)
require.Nil(t, err)
require.Equal(t, r1.CreateAt, o1.CreateAt, "invalid returned command")
if _, err := ss.Command().Get("123"); err == nil {
t.Fatal("Missing id should have failed")
}
_, err = ss.Command().Get("123")
require.NotNil(t, err, "Mising id should have failed")
}
func testCommandStoreGetByTeam(t *testing.T, ss store.Store) {
@@ -74,25 +67,16 @@ func testCommandStoreGetByTeam(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if r1, err := ss.Command().GetByTeam(o1.TeamId); err != nil {
t.Fatal(err)
} else {
if r1[0].CreateAt != o1.CreateAt {
t.Fatal("invalid returned command")
}
}
r1, err := ss.Command().GetByTeam(o1.TeamId)
require.Nil(t, err)
require.NotEmpty(t, r1, "no command returned")
require.Equal(t, r1[0].CreateAt, o1.CreateAt, "invalid returned command")
if result, err := ss.Command().GetByTeam("123"); err != nil {
t.Fatal(err)
} else {
if len(result) != 0 {
t.Fatal("no commands should have returned")
}
}
result, err := ss.Command().GetByTeam("123")
require.Nil(t, err)
require.Empty(t, result, "no commands should have returned")
}
func testCommandStoreGetByTrigger(t *testing.T, ss store.Store) {
@@ -111,30 +95,21 @@ func testCommandStoreGetByTrigger(t *testing.T, ss store.Store) {
o2.Trigger = "trigger1"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
_, err = ss.Command().Save(o2)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
var r1 *model.Command
if r1, err = ss.Command().GetByTrigger(o1.TeamId, o1.Trigger); err != nil {
t.Fatal(err)
} else {
if r1.Id != o1.Id {
t.Fatal("invalid returned command")
}
}
r1, err = ss.Command().GetByTrigger(o1.TeamId, o1.Trigger)
require.Nil(t, err)
require.Equal(t, r1.Id, o1.Id, "invalid returned command")
err = ss.Command().Delete(o1.Id, model.GetMillis())
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if _, err := ss.Command().GetByTrigger(o1.TeamId, o1.Trigger); err == nil {
t.Fatal("no commands should have returned")
}
_, err = ss.Command().GetByTrigger(o1.TeamId, o1.Trigger)
require.NotNil(t, err, "no commands should have returned")
}
func testCommandStoreDelete(t *testing.T, ss store.Store) {
@@ -146,26 +121,17 @@ func testCommandStoreDelete(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if r1, err := ss.Command().Get(o1.Id); err != nil {
t.Fatal(err)
} else {
if r1.CreateAt != o1.CreateAt {
t.Fatal("invalid returned command")
}
}
r1, err := ss.Command().Get(o1.Id)
require.Nil(t, err)
require.Equal(t, r1.CreateAt, o1.CreateAt, "invalid returned command")
if err := ss.Command().Delete(o1.Id, model.GetMillis()); err != nil {
t.Fatal(err)
}
err = ss.Command().Delete(o1.Id, model.GetMillis())
require.Nil(t, err)
if r3, err := ss.Command().Get(o1.Id); err == nil {
t.Log(r3)
t.Fatal("Missing id should have failed")
}
_, err = ss.Command().Get(o1.Id)
require.NotNil(t, err, "Missing id should have failed")
}
func testCommandStoreDeleteByTeam(t *testing.T, ss store.Store) {
@@ -177,26 +143,17 @@ func testCommandStoreDeleteByTeam(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if r1, err := ss.Command().Get(o1.Id); err != nil {
t.Fatal(err)
} else {
if r1.CreateAt != o1.CreateAt {
t.Fatal("invalid returned command")
}
}
r1, err := ss.Command().Get(o1.Id)
require.Nil(t, err)
require.Equal(t, r1.CreateAt, o1.CreateAt, "invalid returned command")
if err := ss.Command().PermanentDeleteByTeam(o1.TeamId); err != nil {
t.Fatal(err)
}
err = ss.Command().PermanentDeleteByTeam(o1.TeamId)
require.Nil(t, err)
if r3, err := ss.Command().Get(o1.Id); err == nil {
t.Log(r3)
t.Fatal("Missing id should have failed")
}
_, err = ss.Command().Get(o1.Id)
require.NotNil(t, err, "Missing id should have failed")
}
func testCommandStoreDeleteByUser(t *testing.T, ss store.Store) {
@@ -208,26 +165,17 @@ func testCommandStoreDeleteByUser(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if r1, err := ss.Command().Get(o1.Id); err != nil {
t.Fatal(err)
} else {
if r1.CreateAt != o1.CreateAt {
t.Fatal("invalid returned command")
}
}
r1, err := ss.Command().Get(o1.Id)
require.Nil(t, err)
require.Equal(t, r1.CreateAt, o1.CreateAt, "invalid returned command")
if err := ss.Command().PermanentDeleteByUser(o1.CreatorId); err != nil {
t.Fatal(err)
}
err = ss.Command().PermanentDeleteByUser(o1.CreatorId)
require.Nil(t, err)
if r3, err := ss.Command().Get(o1.Id); err == nil {
t.Log(r3)
t.Fatal("Missing id should have failed")
}
_, err = ss.Command().Get(o1.Id)
require.NotNil(t, err, "Missing id should have failed")
}
func testCommandStoreUpdate(t *testing.T, ss store.Store) {
@@ -239,21 +187,17 @@ func testCommandStoreUpdate(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
o1.Token = model.NewId()
if _, err := ss.Command().Update(o1); err != nil {
t.Fatal(err)
}
_, err = ss.Command().Update(o1)
require.Nil(t, err)
o1.URL = "junk"
if _, err := ss.Command().Update(o1); err == nil {
t.Fatal("should have failed - bad URL")
}
_, err = ss.Command().Update(o1)
require.NotNil(t, err, "should have failed - bad URL")
}
func testCommandCount(t *testing.T, ss store.Store) {
@@ -265,23 +209,13 @@ func testCommandCount(t *testing.T, ss store.Store) {
o1.Trigger = "trigger"
o1, err := ss.Command().Save(o1)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
if r1, err := ss.Command().AnalyticsCommandCount(""); err != nil {
t.Fatal(err)
} else {
if r1 == 0 {
t.Fatal("should be at least 1 command")
}
}
r1, err := ss.Command().AnalyticsCommandCount("")
require.Nil(t, err)
require.NotZero(t, r1, "should be at least 1 command")
if r2, err := ss.Command().AnalyticsCommandCount(o1.TeamId); err != nil {
t.Fatal(err)
} else {
if r2 != 1 {
t.Fatal("should be 1 command")
}
}
r2, err := ss.Command().AnalyticsCommandCount(o1.TeamId)
require.Nil(t, err)
require.Equal(t, r2, int64(1), "should be 1 command")
}

View File

@@ -21,19 +21,16 @@ func testLicenseStoreSave(t *testing.T, ss store.Store) {
l1.Id = model.NewId()
l1.Bytes = "junk"
if _, err := ss.License().Save(&l1); err != nil {
t.Fatal("couldn't save license record", err)
}
_, err := ss.License().Save(&l1)
require.Nil(t, err, "couldn't save license record")
if _, err := ss.License().Save(&l1); err != nil {
t.Fatal("shouldn't fail on trying to save existing license record", err)
}
_, err = ss.License().Save(&l1)
require.Nil(t, err, "shouldn't fail on trying to save existing license record")
l1.Id = ""
if _, err := ss.License().Save(&l1); err == nil {
t.Fatal("should fail on invalid license", err)
}
_, err = ss.License().Save(&l1)
require.NotNil(t, err, "should fail on invalid license")
}
func testLicenseStoreGet(t *testing.T, ss store.Store) {
@@ -44,15 +41,11 @@ func testLicenseStoreGet(t *testing.T, ss store.Store) {
_, err := ss.License().Save(&l1)
require.Nil(t, err)
if record, err := ss.License().Get(l1.Id); err != nil {
t.Fatal("couldn't get license", err)
} else {
if record.Bytes != l1.Bytes {
t.Fatal("license bytes didn't match")
}
}
record, err := ss.License().Get(l1.Id)
require.Nil(t, err, "couldn't get license")
if _, err := ss.License().Get("missing"); err == nil {
t.Fatal("should fail on get license", err)
}
require.Equal(t, record.Bytes, l1.Bytes, "license bytes didn't match")
_, err = ss.License().Get("missing")
require.NotNil(t, err, "should fail on get license")
}