diff --git a/api4/channel.go b/api4/channel.go index e3a5bf7034..335e198ca9 100644 --- a/api4/channel.go +++ b/api4/channel.go @@ -488,8 +488,13 @@ func getPinnedPosts(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Header().Set(model.HEADER_ETAG_SERVER, posts.Etag()) - w.Write([]byte(c.App.PostListWithProxyAddedToImageURLs(posts).ToJson())) + clientPostList, err := c.App.PreparePostListForClient(posts) + if err != nil { + mlog.Error("Failed to prepare posts for getFlaggedPostsForUser response", mlog.Any("err", err)) + } + + w.Header().Set(model.HEADER_ETAG_SERVER, clientPostList.Etag()) + w.Write([]byte(clientPostList.ToJson())) } func getPublicChannelsForTeam(c *Context, w http.ResponseWriter, r *http.Request) { diff --git a/api4/post.go b/api4/post.go index 7c116b7c77..508b076658 100644 --- a/api4/post.go +++ b/api4/post.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/mattermost/mattermost-server/mlog" "github.com/mattermost/mattermost-server/model" ) @@ -67,8 +68,13 @@ func createPost(c *Context, w http.ResponseWriter, r *http.Request) { c.App.SetStatusOnline(c.Session.UserId, false) c.App.UpdateLastActivityAtIfNeeded(c.Session) + clientPost, err := c.App.PreparePostForClient(rp) + if err != nil { + mlog.Error("Failed to prepare post for createPost response", mlog.Any("err", err)) + } + w.WriteHeader(http.StatusCreated) - w.Write([]byte(c.App.PostWithProxyAddedToImageURLs(rp).ToJson())) + w.Write([]byte(clientPost.ToJson())) } func createEphemeralPost(c *Context, w http.ResponseWriter, r *http.Request) { @@ -95,8 +101,13 @@ func createEphemeralPost(c *Context, w http.ResponseWriter, r *http.Request) { rp := c.App.SendEphemeralPost(ephRequest.UserID, c.App.PostWithProxyRemovedFromImageURLs(ephRequest.Post)) + clientPost, err := c.App.PreparePostForClient(rp) + if err != nil { + mlog.Error("Failed to prepare post for createEphemeralPost response", mlog.Any("err", err)) + } + w.WriteHeader(http.StatusCreated) - w.Write([]byte(c.App.PostWithProxyAddedToImageURLs(rp).ToJson())) + w.Write([]byte(clientPost.ToJson())) } func getPostsForChannel(c *Context, w http.ResponseWriter, r *http.Request) { @@ -165,7 +176,13 @@ func getPostsForChannel(c *Context, w http.ResponseWriter, r *http.Request) { if len(etag) > 0 { w.Header().Set(model.HEADER_ETAG_SERVER, etag) } - w.Write([]byte(c.App.PostListWithProxyAddedToImageURLs(list).ToJson())) + + clientPostList, err := c.App.PreparePostListForClient(list) + if err != nil { + mlog.Error("Failed to prepare posts for getPostsForChannel response", mlog.Any("err", err)) + } + + w.Write([]byte(clientPostList.ToJson())) } func getFlaggedPostsForUser(c *Context, w http.ResponseWriter, r *http.Request) { @@ -198,7 +215,12 @@ func getFlaggedPostsForUser(c *Context, w http.ResponseWriter, r *http.Request) return } - w.Write([]byte(c.App.PostListWithProxyAddedToImageURLs(posts).ToJson())) + clientPostList, err := c.App.PreparePostListForClient(posts) + if err != nil { + mlog.Error("Failed to prepare posts for getFlaggedPostsForUser response", mlog.Any("err", err)) + } + + w.Write([]byte(clientPostList.ToJson())) } func getPost(c *Context, w http.ResponseWriter, r *http.Request) { @@ -232,12 +254,17 @@ func getPost(c *Context, w http.ResponseWriter, r *http.Request) { } } + post, err = c.App.PreparePostForClient(post) + if err != nil { + mlog.Error("Failed to prepare post for getPost response", mlog.Any("err", err)) + } + if c.HandleEtag(post.Etag(), "Get Post", w, r) { return } w.Header().Set(model.HEADER_ETAG_SERVER, post.Etag()) - w.Write([]byte(c.App.PostWithProxyAddedToImageURLs(post).ToJson())) + w.Write([]byte(post.ToJson())) } func deletePost(c *Context, w http.ResponseWriter, r *http.Request) { @@ -315,8 +342,14 @@ func getPostThread(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Header().Set(model.HEADER_ETAG_SERVER, list.Etag()) - w.Write([]byte(c.App.PostListWithProxyAddedToImageURLs(list).ToJson())) + clientPostList, err := c.App.PreparePostListForClient(list) + if err != nil { + mlog.Error("Failed to prepare posts for getFlaggedPostsForUser response", mlog.Any("err", err)) + } + + w.Header().Set(model.HEADER_ETAG_SERVER, clientPostList.Etag()) + + w.Write([]byte(clientPostList.ToJson())) } func searchPosts(c *Context, w http.ResponseWriter, r *http.Request) { @@ -379,7 +412,12 @@ func searchPosts(c *Context, w http.ResponseWriter, r *http.Request) { return } - results = model.MakePostSearchResults(c.App.PostListWithProxyAddedToImageURLs(results.PostList), results.Matches) + clientPostList, err := c.App.PreparePostListForClient(results.PostList) + if err != nil { + mlog.Error("Failed to prepare posts for searchPosts response", mlog.Any("err", err)) + } + + results = model.MakePostSearchResults(clientPostList, results.Matches) w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Write([]byte(results.ToJson())) @@ -430,7 +468,12 @@ func updatePost(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte(c.App.PostWithProxyAddedToImageURLs(rpost).ToJson())) + rpost, err = c.App.PreparePostForClient(rpost) + if err != nil { + mlog.Error("Failed to prepare post for updatePost response", mlog.Any("err", err)) + } + + w.Write([]byte(rpost.ToJson())) } func patchPost(c *Context, w http.ResponseWriter, r *http.Request) { @@ -470,7 +513,12 @@ func patchPost(c *Context, w http.ResponseWriter, r *http.Request) { return } - w.Write([]byte(c.App.PostWithProxyAddedToImageURLs(patchedPost).ToJson())) + patchedPost, err = c.App.PreparePostForClient(patchedPost) + if err != nil { + mlog.Error("Failed to prepare post for patchPost response", mlog.Any("err", err)) + } + + w.Write([]byte(patchedPost.ToJson())) } func saveIsPinnedPost(c *Context, w http.ResponseWriter, r *http.Request, isPinned bool) { diff --git a/app/apptestlib.go b/app/apptestlib.go index dcc1fa9419..0286caf7c5 100644 --- a/app/apptestlib.go +++ b/app/apptestlib.go @@ -390,6 +390,39 @@ func (me *TestHelper) CreateScheme() (*model.Scheme, []*model.Role) { return scheme, roles } +func (me *TestHelper) CreateEmoji() *model.Emoji { + utils.DisableDebugLogForTest() + + result := <-me.App.Srv.Store.Emoji().Save(&model.Emoji{ + CreatorId: me.BasicUser.Id, + Name: model.NewRandomString(10), + }) + if result.Err != nil { + panic(result.Err) + } + + utils.EnableDebugLogForTest() + + return result.Data.(*model.Emoji) +} + +func (me *TestHelper) AddReactionToPost(post *model.Post, user *model.User, emojiName string) *model.Reaction { + utils.DisableDebugLogForTest() + + reaction, err := me.App.SaveReactionForPost(&model.Reaction{ + UserId: user.Id, + PostId: post.Id, + EmojiName: emojiName, + }) + if err != nil { + panic(err) + } + + utils.EnableDebugLogForTest() + + return reaction +} + func (me *TestHelper) TearDown() { me.App.Shutdown() os.Remove(me.tempConfigPath) diff --git a/app/emoji.go b/app/emoji.go index c0eda18a17..666558b2a3 100644 --- a/app/emoji.go +++ b/app/emoji.go @@ -185,6 +185,18 @@ func (a *App) GetEmojiByName(emojiName string) (*model.Emoji, *model.AppError) { return result.Data.(*model.Emoji), nil } +func (a *App) GetMultipleEmojiByName(names []string) ([]*model.Emoji, *model.AppError) { + if !*a.Config().ServiceSettings.EnableCustomEmoji { + return nil, model.NewAppError("GetMultipleEmojiByName", "api.emoji.disabled.app_error", nil, "", http.StatusNotImplemented) + } + + if result := <-a.Srv.Store.Emoji().GetMultipleByName(names); result.Err != nil { + return nil, result.Err + } else { + return result.Data.([]*model.Emoji), nil + } +} + func (a *App) GetEmojiImage(emojiId string) ([]byte, string, *model.AppError) { result := <-a.Srv.Store.Emoji().Get(emojiId, true) if result.Err != nil { diff --git a/app/notification.go b/app/notification.go index 54f1f470da..ea5bd899ba 100644 --- a/app/notification.go +++ b/app/notification.go @@ -317,8 +317,13 @@ func (a *App) SendNotifications(post *model.Post, team *model.Team, channel *mod } } + clientPost, err := a.PreparePostForClient(post) + if err != nil { + mlog.Error("Failed to prepare new post for client", mlog.Any("err", err)) + } + message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POSTED, "", post.ChannelId, "", nil) - message.Add("post", a.PostWithProxyAddedToImageURLs(post).ToJson()) + message.Add("post", clientPost.ToJson()) message.Add("channel_type", channel.Type) message.Add("channel_display_name", notification.GetChannelName(model.SHOW_USERNAME, "")) message.Add("channel_name", channel.Name) diff --git a/app/post.go b/app/post.go index c882fc0562..368053c3bf 100644 --- a/app/post.go +++ b/app/post.go @@ -301,8 +301,13 @@ func (a *App) SendEphemeralPost(userId string, post *model.Post) *model.Post { post.Props = model.StringInterface{} } + clientPost, err := a.PreparePostForClient(post) + if err != nil { + mlog.Error("Failed to prepare ephemeral post for client", mlog.Any("err", err)) + } + message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_EPHEMERAL_MESSAGE, "", post.ChannelId, userId, nil) - message.Add("post", a.PostWithProxyAddedToImageURLs(post).ToJson()) + message.Add("post", clientPost.ToJson()) a.Publish(message) return post @@ -423,8 +428,13 @@ func (a *App) PatchPost(postId string, patch *model.PostPatch) (*model.Post, *mo } func (a *App) sendUpdatedPostEvent(post *model.Post) { + clientPost, err := a.PreparePostForClient(post) + if err != nil { + mlog.Error("Failed to prepare updated post for client", mlog.Any("err", err)) + } + message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POST_EDITED, "", post.ChannelId, "", nil) - message.Add("post", a.PostWithProxyAddedToImageURLs(post).ToJson()) + message.Add("post", clientPost.ToJson()) a.Publish(message) } @@ -563,8 +573,13 @@ func (a *App) DeletePost(postId, deleteByID string) (*model.Post, *model.AppErro return nil, result.Err } + clientPost, err := a.PreparePostForClient(post) + if err != nil { + mlog.Error("Failed to prepare deleted post for client", mlog.Any("err", err)) + } + message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POST_DELETED, "", post.ChannelId, "", nil) - message.Add("post", a.PostWithProxyAddedToImageURLs(post).ToJson()) + message.Add("post", clientPost.ToJson()) a.Publish(message) a.Go(func() { @@ -967,13 +982,6 @@ func (a *App) DoPostAction(postId, actionId, userId, selectedOption string) *mod return nil } -func (a *App) PostListWithProxyAddedToImageURLs(list *model.PostList) *model.PostList { - if f := a.ImageProxyAdder(); f != nil { - return list.WithRewrittenImageURLs(f) - } - return list -} - func (a *App) PostWithProxyAddedToImageURLs(post *model.Post) *model.Post { if f := a.ImageProxyAdder(); f != nil { return post.WithRewrittenImageURLs(f) diff --git a/app/post_metadata.go b/app/post_metadata.go new file mode 100644 index 0000000000..dd7e66774d --- /dev/null +++ b/app/post_metadata.go @@ -0,0 +1,111 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "strings" + + "github.com/dyatlov/go-opengraph/opengraph" + "github.com/mattermost/mattermost-server/model" +) + +func (a *App) PreparePostListForClient(originalList *model.PostList) (*model.PostList, *model.AppError) { + list := &model.PostList{ + Posts: make(map[string]*model.Post), + Order: originalList.Order, + } + + for id, originalPost := range originalList.Posts { + post, err := a.PreparePostForClient(originalPost) + if err != nil { + return originalList, err + } + + list.Posts[id] = post + } + + return list, nil +} + +func (a *App) PreparePostForClient(originalPost *model.Post) (*model.Post, *model.AppError) { + post := originalPost.Clone() + + var err *model.AppError + + needReactionCounts := post.ReactionCounts == nil + needEmojis := post.Emojis == nil + needImageDimensions := post.ImageDimensions == nil + needOpenGraphData := post.OpenGraphData == nil + + var reactions []*model.Reaction + if needReactionCounts || needEmojis { + reactions, err = a.GetReactionsForPost(post.Id) + if err != nil { + return post, err + } + } + + if needReactionCounts { + post.ReactionCounts = model.CountReactions(reactions) + } + + if post.FileInfos == nil { + fileInfos, err := a.GetFileInfosForPost(post.Id, false) + if err != nil { + return post, err + } + + post.FileInfos = fileInfos + } + + if needEmojis { + emojis, err := a.getCustomEmojisForPost(post.Message, reactions) + if err != nil { + return post, err + } + + post.Emojis = emojis + } + + post = a.PostWithProxyAddedToImageURLs(post) + + if needImageDimensions || needOpenGraphData { + if needImageDimensions { + post.ImageDimensions = []*model.PostImageDimensions{} + } + + if needOpenGraphData { + post.OpenGraphData = []*opengraph.OpenGraph{} + } + + // TODO + } + + return post, nil +} + +func (a *App) getCustomEmojisForPost(message string, reactions []*model.Reaction) ([]*model.Emoji, *model.AppError) { + if !*a.Config().ServiceSettings.EnableCustomEmoji { + // Only custom emoji are returned + return []*model.Emoji{}, nil + } + + names := model.EMOJI_PATTERN.FindAllString(message, -1) + + for _, reaction := range reactions { + names = append(names, reaction.EmojiName) + } + + if len(names) == 0 { + return []*model.Emoji{}, nil + } + + names = model.RemoveDuplicateStrings(names) + + for i, name := range names { + names[i] = strings.Trim(name, ":") + } + + return a.GetMultipleEmojiByName(names) +} diff --git a/app/post_metadata_test.go b/app/post_metadata_test.go new file mode 100644 index 0000000000..c1102a2424 --- /dev/null +++ b/app/post_metadata_test.go @@ -0,0 +1,356 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "fmt" + "testing" + "time" + + "github.com/mattermost/mattermost-server/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPreparePostForClient(t *testing.T) { + setup := func() *TestHelper { + th := Setup().InitBasic() + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.ImageProxyType = "" + *cfg.ServiceSettings.ImageProxyURL = "" + *cfg.ServiceSettings.ImageProxyOptions = "" + }) + + return th + } + + t.Run("no metadata needed", func(t *testing.T) { + th := setup() + defer th.TearDown() + + post := th.CreatePost(th.BasicChannel) + message := post.Message + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + assert.NotEqual(t, clientPost, post, "should've returned a new post") + assert.Equal(t, message, post.Message, "shouldn't have mutated post.Message") + assert.NotEqual(t, nil, post.ReactionCounts, "shouldn't have mutated post.ReactionCounts") + assert.NotEqual(t, nil, post.FileInfos, "shouldn't have mutated post.FileInfos") + assert.NotEqual(t, nil, post.Emojis, "shouldn't have mutated post.Emojis") + assert.NotEqual(t, nil, post.ImageDimensions, "shouldn't have mutated post.ImageDimensions") + assert.NotEqual(t, nil, post.OpenGraphData, "shouldn't have mutated post.OpenGraphData") + + assert.Equal(t, clientPost.Message, post.Message, "shouldn't have changed Message") + assert.Len(t, post.ReactionCounts, 0, "should've populated ReactionCounts") + assert.Len(t, post.FileInfos, 0, "should've populated FileInfos") + assert.Len(t, post.Emojis, 0, "should've populated Emojis") + assert.Len(t, post.ImageDimensions, 0, "should've populated ImageDimensions") + assert.Len(t, post.OpenGraphData, 0, "should've populated OpenGraphData") + }) + + t.Run("metadata already set", func(t *testing.T) { + th := setup() + defer th.TearDown() + + post, err := th.App.PreparePostForClient(th.CreatePost(th.BasicChannel)) + require.Nil(t, err) + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + assert.False(t, clientPost == post, "should've returned a new post") + assert.Equal(t, clientPost, post, "shouldn't have changed any metadata") + }) + + t.Run("reaction counts", func(t *testing.T) { + th := setup() + defer th.TearDown() + + post := th.CreatePost(th.BasicChannel) + th.AddReactionToPost(post, th.BasicUser, "smile") + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + assert.Equal(t, model.ReactionCounts{ + "smile": 1, + }, clientPost.ReactionCounts, "should've populated post.ReactionCounts") + }) + + t.Run("file infos", func(t *testing.T) { + th := setup() + defer th.TearDown() + + fileInfo, err := th.App.DoUploadFile(time.Now(), th.BasicTeam.Id, th.BasicChannel.Id, th.BasicUser.Id, "test.txt", []byte("test")) + require.Nil(t, err) + + post, err := th.App.CreatePost(&model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + FileIds: []string{fileInfo.Id}, + }, th.BasicChannel, false) + require.Nil(t, err) + + fileInfo.PostId = post.Id + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + assert.Equal(t, []*model.FileInfo{fileInfo}, clientPost.FileInfos, "should've populated post.FileInfos") + }) + + t.Run("emojis without custom emojis enabled", func(t *testing.T) { + th := setup() + defer th.TearDown() + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableCustomEmoji = false + }) + + emoji := th.CreateEmoji() + + post, err := th.App.CreatePost(&model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: ":" + emoji.Name + ": :taco:", + }, th.BasicChannel, false) + require.Nil(t, err) + + th.AddReactionToPost(post, th.BasicUser, "smile") + th.AddReactionToPost(post, th.BasicUser, "angry") + th.AddReactionToPost(post, th.BasicUser2, "angry") + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + assert.Len(t, clientPost.ReactionCounts, 2, "should've populated post.ReactionCounts") + assert.Equal(t, 1, clientPost.ReactionCounts["smile"], "should've populated post.ReactionCounts for smile") + assert.Equal(t, 2, clientPost.ReactionCounts["angry"], "should've populated post.ReactionCounts for angry") + assert.ElementsMatch(t, []*model.Emoji{}, clientPost.Emojis, "should've populated empty post.Emojis") + }) + + t.Run("emojis with custom emojis enabled", func(t *testing.T) { + th := setup() + defer th.TearDown() + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableCustomEmoji = true + }) + + emoji1 := th.CreateEmoji() + emoji2 := th.CreateEmoji() + emoji3 := th.CreateEmoji() + + post, err := th.App.CreatePost(&model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: ":" + emoji3.Name + ": :taco:", + }, th.BasicChannel, false) + require.Nil(t, err) + + th.AddReactionToPost(post, th.BasicUser, emoji1.Name) + th.AddReactionToPost(post, th.BasicUser, emoji2.Name) + th.AddReactionToPost(post, th.BasicUser2, emoji2.Name) + th.AddReactionToPost(post, th.BasicUser2, "angry") + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + assert.Len(t, clientPost.ReactionCounts, 3, "should've populated post.ReactionCounts") + assert.Equal(t, 1, clientPost.ReactionCounts[emoji1.Name], "should've populated post.ReactionCounts for emoji1") + assert.Equal(t, 2, clientPost.ReactionCounts[emoji2.Name], "should've populated post.ReactionCounts for emoji2") + assert.Equal(t, 1, clientPost.ReactionCounts["angry"], "should've populated post.ReactionCounts for angry") + assert.ElementsMatch(t, []*model.Emoji{emoji1, emoji2, emoji3}, clientPost.Emojis, "should've populated post.Emojis") + }) + + t.Run("linked image dimensions", func(t *testing.T) { + // TODO + }) + + t.Run("proxy linked images", func(t *testing.T) { + th := setup() + defer th.TearDown() + + testProxyLinkedImage(t, th, false) + }) + + t.Run("opengraph", func(t *testing.T) { + // TODO + }) + + t.Run("opengraph image dimensions", func(t *testing.T) { + // TODO + }) + + t.Run("proxy opengraph images", func(t *testing.T) { + // TODO + }) +} + +func TestPreparePostForClientWithImageProxy(t *testing.T) { + setup := func() *TestHelper { + th := Setup().InitBasic() + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.SiteURL = "http://mymattermost.com" + *cfg.ServiceSettings.ImageProxyType = "atmos/camo" + *cfg.ServiceSettings.ImageProxyURL = "https://127.0.0.1" + *cfg.ServiceSettings.ImageProxyOptions = "foo" + }) + + return th + } + + t.Run("proxy linked images", func(t *testing.T) { + th := setup() + defer th.TearDown() + + testProxyLinkedImage(t, th, true) + }) + + t.Run("proxy opengraph images", func(t *testing.T) { + // TODO + }) +} + +func testProxyLinkedImage(t *testing.T, th *TestHelper, shouldProxy bool) { + postTemplate := "![foo](%v)" + imageURL := "http://mydomain.com/myimage" + proxiedImageURL := "https://127.0.0.1/f8dace906d23689e8d5b12c3cefbedbf7b9b72f5/687474703a2f2f6d79646f6d61696e2e636f6d2f6d79696d616765" + + post := &model.Post{ + UserId: th.BasicUser.Id, + ChannelId: th.BasicChannel.Id, + Message: fmt.Sprintf(postTemplate, imageURL), + } + + var err *model.AppError + post, err = th.App.CreatePost(post, th.BasicChannel, false) + require.Nil(t, err) + + clientPost, err := th.App.PreparePostForClient(post) + require.Nil(t, err) + + if shouldProxy { + assert.Equal(t, post.Message, fmt.Sprintf(postTemplate, imageURL), "should not have mutated original post") + assert.Equal(t, clientPost.Message, fmt.Sprintf(postTemplate, proxiedImageURL), "should've replaced linked image URLs") + } else { + assert.Equal(t, clientPost.Message, fmt.Sprintf(postTemplate, imageURL), "shouldn't have replaced linked image URLs") + } +} + +func TestGetCustomEmojisForPost_Message(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableCustomEmoji = true + }) + + emoji1 := th.CreateEmoji() + emoji2 := th.CreateEmoji() + emoji3 := th.CreateEmoji() + + testCases := []struct { + Description string + Input string + Expected []*model.Emoji + SkipExpectations bool + }{ + { + Description: "no emojis", + Input: "this is a string", + Expected: []*model.Emoji{}, + SkipExpectations: true, + }, + { + Description: "one emoji", + Input: "this is an :" + emoji1.Name + ": string", + Expected: []*model.Emoji{ + emoji1, + }, + }, + { + Description: "two emojis", + Input: "this is a :" + emoji3.Name + ": :" + emoji2.Name + ": string", + Expected: []*model.Emoji{ + emoji3, + emoji2, + }, + }, + { + Description: "punctuation around emojis", + Input: ":" + emoji3.Name + ":/:" + emoji1.Name + ": (:" + emoji2.Name + ":)", + Expected: []*model.Emoji{ + emoji3, + emoji1, + emoji2, + }, + }, + { + Description: "adjacent emojis", + Input: ":" + emoji3.Name + "::" + emoji1.Name + ":", + Expected: []*model.Emoji{ + emoji3, + emoji1, + }, + }, + { + Description: "duplicate emojis", + Input: "" + emoji1.Name + ": :" + emoji1.Name + ": :" + emoji1.Name + ": :" + emoji2.Name + ": :" + emoji2.Name + ": :" + emoji1.Name + ":", + Expected: []*model.Emoji{ + emoji1, + emoji2, + }, + }, + { + Description: "fake emojis", + Input: "these don't exist :tomato: :potato: :rotato:", + Expected: []*model.Emoji{}, + }, + { + Description: "fake and real emojis", + Input: ":tomato::" + emoji1.Name + ": :potato: :" + emoji2.Name + ":", + Expected: []*model.Emoji{ + emoji1, + emoji2, + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.Description, func(t *testing.T) { + emojis, err := th.App.getCustomEmojisForPost(testCase.Input, nil) + assert.Nil(t, err, "failed to get emojis in message") + assert.ElementsMatch(t, emojis, testCase.Expected, "received incorrect emojis") + }) + } +} + +func TestGetCustomEmojisForPost(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.EnableCustomEmoji = true + }) + + emoji1 := th.CreateEmoji() + emoji2 := th.CreateEmoji() + + reactions := []*model.Reaction{ + { + UserId: th.BasicUser.Id, + EmojiName: emoji1.Name, + }, + } + + emojis, err := th.App.getCustomEmojisForPost(":"+emoji2.Name+":", reactions) + assert.Nil(t, err, "failed to get emojis for post") + assert.ElementsMatch(t, emojis, []*model.Emoji{emoji1, emoji2}, "received incorrect emojis") +} diff --git a/app/post_test.go b/app/post_test.go index 5d93d3f0f5..2820d6e027 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -466,7 +466,6 @@ func TestImageProxy(t *testing.T) { list := model.NewPostList() list.Posts[post.Id] = post - assert.Equal(t, "![foo]("+tc.ProxiedImageURL+")", th.App.PostListWithProxyAddedToImageURLs(list).Posts[post.Id].Message) assert.Equal(t, "![foo]("+tc.ProxiedImageURL+")", th.App.PostWithProxyAddedToImageURLs(post).Message) assert.Equal(t, "![foo]("+tc.ImageURL+")", th.App.PostWithProxyRemovedFromImageURLs(post).Message) diff --git a/app/reaction.go b/app/reaction.go index 41fc7fca41..6167d77b40 100644 --- a/app/reaction.go +++ b/app/reaction.go @@ -6,6 +6,7 @@ package app import ( "net/http" + "github.com/mattermost/mattermost-server/mlog" "github.com/mattermost/mattermost-server/model" ) @@ -42,6 +43,9 @@ func (a *App) SaveReactionForPost(reaction *model.Reaction) (*model.Reaction, *m reaction = result.Data.(*model.Reaction) + // The post is always modified since the UpdateAt always changes + a.InvalidateCacheForChannelPosts(post.ChannelId) + a.Go(func() { a.sendReactionEvent(model.WEBSOCKET_EVENT_REACTION_ADDED, reaction, post, true) }) @@ -92,6 +96,9 @@ func (a *App) DeleteReactionForPost(reaction *model.Reaction) *model.AppError { return result.Err } + // The post is always modified since the UpdateAt always changes + a.InvalidateCacheForChannelPosts(post.ChannelId) + a.Go(func() { a.sendReactionEvent(model.WEBSOCKET_EVENT_REACTION_REMOVED, reaction, post, hasReactions) }) @@ -105,11 +112,15 @@ func (a *App) sendReactionEvent(event string, reaction *model.Reaction, post *mo message.Add("reaction", reaction.ToJson()) a.Publish(message) - // The post is always modified since the UpdateAt always changes - a.InvalidateCacheForChannelPosts(post.ChannelId) - post.HasReactions = hasReactions - post.UpdateAt = model.GetMillis() + clientPost, err := a.PreparePostForClient(post) + if err != nil { + mlog.Error("Failed to prepare new post for client after reaction", mlog.Any("err", err)) + } + + clientPost.HasReactions = hasReactions + clientPost.UpdateAt = model.GetMillis() + umessage := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_POST_EDITED, "", post.ChannelId, "", nil) - umessage.Add("post", a.PostWithProxyAddedToImageURLs(post).ToJson()) + umessage.Add("post", clientPost.ToJson()) a.Publish(umessage) } diff --git a/model/emoji.go b/model/emoji.go index f14af89df8..afe61493d7 100644 --- a/model/emoji.go +++ b/model/emoji.go @@ -7,6 +7,7 @@ import ( "encoding/json" "io" "net/http" + "regexp" ) const ( @@ -14,6 +15,8 @@ const ( EMOJI_SORT_BY_NAME = "name" ) +var EMOJI_PATTERN = regexp.MustCompile(`:[a-zA-Z0-9_-]+:`) + type Emoji struct { Id string `json:"id"` CreateAt int64 `json:"create_at"` diff --git a/model/post.go b/model/post.go index 5d2438fc4e..3bd74a1000 100644 --- a/model/post.go +++ b/model/post.go @@ -11,6 +11,7 @@ import ( "strings" "unicode/utf8" + "github.com/dyatlov/go-opengraph/opengraph" "github.com/mattermost/mattermost-server/utils/markdown" ) @@ -78,9 +79,22 @@ type Post struct { Props StringInterface `json:"props"` Hashtags string `json:"hashtags"` Filenames StringArray `json:"filenames,omitempty"` // Deprecated, do not use this field any more - FileIds StringArray `json:"file_ids,omitempty"` + FileIds StringArray `json:"file_ids,omitempty"` // Deprecated, do not use this field any more PendingPostId string `json:"pending_post_id" db:"-"` - HasReactions bool `json:"has_reactions,omitempty"` + HasReactions bool `json:"has_reactions,omitempty"` // Deprecated, do not use this field any more + + // Transient fields populated before sending posts to the client + ReactionCounts ReactionCounts `json:"reaction_counts" db:"-"` + FileInfos []*FileInfo `json:"file_infos" db:"-"` + ImageDimensions []*PostImageDimensions `json:"image_dimensions" db:"-"` + OpenGraphData []*opengraph.OpenGraph `json:"opengraph_data" db:"-"` + Emojis []*Emoji `json:"emojis" db:"-"` +} + +type PostImageDimensions struct { + URL string `json:"url"` + Width int64 `json:"width"` + Height int64 `json:"height"` } type PostEphemeral struct { @@ -170,10 +184,16 @@ type PostActionIntegrationResponse struct { EphemeralText string `json:"ephemeral_text"` } -func (o *Post) ToJson() string { +// Shallowly clone the a post +func (o *Post) Clone() *Post { copy := *o + return © +} + +func (o *Post) ToJson() string { + copy := o.Clone() copy.StripActionIntegrations() - b, _ := json.Marshal(©) + b, _ := json.Marshal(copy) return string(b) } @@ -502,12 +522,12 @@ var markdownDestinationEscaper = strings.NewReplacer( // WithRewrittenImageURLs returns a new shallow copy of the post where the message has been // rewritten via RewriteImageURLs. func (o *Post) WithRewrittenImageURLs(f func(string) string) *Post { - copy := *o + copy := o.Clone() copy.Message = RewriteImageURLs(o.Message, f) if copy.MessageSource == "" && copy.Message != o.Message { copy.MessageSource = o.Message } - return © + return copy } func (o *PostEphemeral) ToUnsanitizedJson() string { diff --git a/model/reaction.go b/model/reaction.go index c1b9c499a8..8eb0674d4d 100644 --- a/model/reaction.go +++ b/model/reaction.go @@ -17,6 +17,8 @@ type Reaction struct { CreateAt int64 `json:"create_at"` } +type ReactionCounts map[string]int + func (o *Reaction) ToJson() string { b, _ := json.Marshal(o) return string(b) @@ -74,3 +76,13 @@ func (o *Reaction) PreSave() { o.CreateAt = GetMillis() } } + +func CountReactions(reactions []*Reaction) ReactionCounts { + reactionCounts := ReactionCounts{} + + for _, reaction := range reactions { + reactionCounts[reaction.EmojiName] += 1 + } + + return reactionCounts +} diff --git a/model/reaction_test.go b/model/reaction_test.go index a357504775..26d3a6bd22 100644 --- a/model/reaction_test.go +++ b/model/reaction_test.go @@ -82,3 +82,38 @@ func TestReactionIsValid(t *testing.T) { t.Fatal("create at should be invalid") } } + +func TestCountReactions(t *testing.T) { + userId := NewId() + userId2 := NewId() + + reactions := []*Reaction{ + { + UserId: userId, + EmojiName: "smile", + }, + { + UserId: userId, + EmojiName: "frowning", + }, + { + UserId: userId2, + EmojiName: "smile", + }, + { + UserId: userId2, + EmojiName: "neutral_face", + }, + } + + reactionCounts := CountReactions(reactions) + if len(reactionCounts) != 3 { + t.Fatal("should've received counts for 3 reactions") + } else if reactionCounts["smile"] != 2 { + t.Fatal("should've received 2 smile reactions") + } else if reactionCounts["frowning"] != 1 { + t.Fatal("should've received 1 frowning reaction") + } else if reactionCounts["neutral_face"] != 1 { + t.Fatal("should've received 2 neutral_face reaction") + } +} diff --git a/store/local_cache_supplier_roles.go b/store/local_cache_supplier_roles.go index 41f88a216e..769b9d0d56 100644 --- a/store/local_cache_supplier_roles.go +++ b/store/local_cache_supplier_roles.go @@ -18,10 +18,11 @@ func (s *LocalCacheSupplier) handleClusterInvalidateRole(msg *model.ClusterMessa } func (s *LocalCacheSupplier) RoleSave(ctx context.Context, role *model.Role, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().RoleSave(ctx, role, hints...) if len(role.Id) != 0 { - defer s.doInvalidateCacheCluster(s.roleCache, role.Name) + s.doInvalidateCacheCluster(s.roleCache, role.Name) } - return s.Next().RoleSave(ctx, role, hints...) + return result } func (s *LocalCacheSupplier) RoleGet(ctx context.Context, roleId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { @@ -81,8 +82,10 @@ func (s *LocalCacheSupplier) RoleDelete(ctx context.Context, roleId string, hint } func (s *LocalCacheSupplier) RolePermanentDeleteAll(ctx context.Context, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { - defer s.roleCache.Purge() - defer s.doClearCacheCluster(s.roleCache) + result := s.Next().RolePermanentDeleteAll(ctx, hints...) - return s.Next().RolePermanentDeleteAll(ctx, hints...) + s.roleCache.Purge() + s.doClearCacheCluster(s.roleCache) + + return result } diff --git a/store/local_cache_supplier_schemes.go b/store/local_cache_supplier_schemes.go index 8dd1fededf..5bf84e15b5 100644 --- a/store/local_cache_supplier_schemes.go +++ b/store/local_cache_supplier_schemes.go @@ -18,10 +18,11 @@ func (s *LocalCacheSupplier) handleClusterInvalidateScheme(msg *model.ClusterMes } func (s *LocalCacheSupplier) SchemeSave(ctx context.Context, scheme *model.Scheme, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemeSave(ctx, scheme, hints...) if len(scheme.Id) != 0 { - defer s.doInvalidateCacheCluster(s.schemeCache, scheme.Id) + s.doInvalidateCacheCluster(s.schemeCache, scheme.Id) } - return s.Next().SchemeSave(ctx, scheme, hints...) + return result } func (s *LocalCacheSupplier) SchemeGet(ctx context.Context, schemeId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { @@ -41,10 +42,12 @@ func (s *LocalCacheSupplier) SchemeGetByName(ctx context.Context, schemeName str } func (s *LocalCacheSupplier) SchemeDelete(ctx context.Context, schemeId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { - defer s.doInvalidateCacheCluster(s.schemeCache, schemeId) - defer s.doClearCacheCluster(s.roleCache) + result := s.Next().SchemeDelete(ctx, schemeId, hints...) - return s.Next().SchemeDelete(ctx, schemeId, hints...) + s.doInvalidateCacheCluster(s.schemeCache, schemeId) + s.doClearCacheCluster(s.roleCache) + + return result } func (s *LocalCacheSupplier) SchemeGetAllPage(ctx context.Context, scope string, offset int, limit int, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { diff --git a/store/redis_supplier_roles.go b/store/redis_supplier_roles.go index a445c38a9f..b511a49933 100644 --- a/store/redis_supplier_roles.go +++ b/store/redis_supplier_roles.go @@ -13,6 +13,7 @@ import ( func (s *RedisSupplier) RoleSave(ctx context.Context, role *model.Role, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { key := buildRedisKeyForRoleName(role.Name) + result := s.Next().RoleSave(ctx, role, hints...) defer func() { if err := s.client.Del(key).Err(); err != nil { @@ -20,7 +21,7 @@ func (s *RedisSupplier) RoleSave(ctx context.Context, role *model.Role, hints .. } }() - return s.Next().RoleSave(ctx, role, hints...) + return result } func (s *RedisSupplier) RoleGet(ctx context.Context, roleId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { @@ -86,6 +87,7 @@ func (s *RedisSupplier) RoleGetByNames(ctx context.Context, roleNames []string, } func (s *RedisSupplier) RoleDelete(ctx context.Context, roleId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + // XXXXXX Shouldn't this call Role result := s.Next().RoleGet(ctx, roleId, hints...) if result.Err == nil { @@ -103,17 +105,17 @@ func (s *RedisSupplier) RoleDelete(ctx context.Context, roleId string, hints ... } func (s *RedisSupplier) RolePermanentDeleteAll(ctx context.Context, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { - defer func() { - if keys, err := s.client.Keys("roles:*").Result(); err != nil { - mlog.Error("Redis encountered an error on read: " + err.Error()) - } else { - if err := s.client.Del(keys...).Err(); err != nil { - mlog.Error("Redis encountered an error on delete: " + err.Error()) - } - } - }() + result := s.Next().RolePermanentDeleteAll(ctx, hints...) - return s.Next().RolePermanentDeleteAll(ctx, hints...) + if keys, err := s.client.Keys("roles:*").Result(); err != nil { + mlog.Error("Redis encountered an error on read: " + err.Error()) + } else { + if err := s.client.Del(keys...).Err(); err != nil { + mlog.Error("Redis encountered an error on delete: " + err.Error()) + } + } + + return result } func buildRedisKeyForRoleName(roleName string) string { diff --git a/store/redis_supplier_schemes.go b/store/redis_supplier_schemes.go index ae33361486..49a6c348fb 100644 --- a/store/redis_supplier_schemes.go +++ b/store/redis_supplier_schemes.go @@ -10,31 +10,37 @@ import ( ) func (s *RedisSupplier) SchemeSave(ctx context.Context, scheme *model.Scheme, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemeSave(ctx, scheme, hints...) // TODO: Redis caching. - return s.Next().SchemeSave(ctx, scheme, hints...) + return result } func (s *RedisSupplier) SchemeGet(ctx context.Context, schemeId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemeGet(ctx, schemeId, hints...) // TODO: Redis caching. - return s.Next().SchemeGet(ctx, schemeId, hints...) + return result } func (s *RedisSupplier) SchemeGetByName(ctx context.Context, schemeName string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemeGetByName(ctx, schemeName, hints...) // TODO: Redis caching. - return s.Next().SchemeGetByName(ctx, schemeName, hints...) + return result } func (s *RedisSupplier) SchemeDelete(ctx context.Context, schemeId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemeDelete(ctx, schemeId, hints...) // TODO: Redis caching. - return s.Next().SchemeDelete(ctx, schemeId, hints...) + return result } func (s *RedisSupplier) SchemeGetAllPage(ctx context.Context, scope string, offset int, limit int, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemeGetAllPage(ctx, scope, offset, limit, hints...) // TODO: Redis caching. - return s.Next().SchemeGetAllPage(ctx, scope, offset, limit, hints...) + return result } func (s *RedisSupplier) SchemePermanentDeleteAll(ctx context.Context, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + result := s.Next().SchemePermanentDeleteAll(ctx, hints...) // TODO: Redis caching. - return s.Next().SchemePermanentDeleteAll(ctx, hints...) + return result } diff --git a/store/sqlstore/emoji_store.go b/store/sqlstore/emoji_store.go index 971cafb6a8..b480c21649 100644 --- a/store/sqlstore/emoji_store.go +++ b/store/sqlstore/emoji_store.go @@ -5,6 +5,7 @@ package sqlstore import ( "database/sql" + "fmt" "net/http" "github.com/mattermost/mattermost-server/einterfaces" @@ -128,6 +129,27 @@ func (es SqlEmojiStore) GetByName(name string) store.StoreChannel { }) } +func (es SqlEmojiStore) GetMultipleByName(names []string) store.StoreChannel { + return store.Do(func(result *store.StoreResult) { + keys, params := MapStringsToQueryParams(names, "Emoji") + + var emojis []*model.Emoji + + if _, err := es.GetReplica().Select(&emojis, + `SELECT + * + FROM + Emoji + WHERE + Name IN `+keys+` + AND DeleteAt = 0`, params); err != nil { + result.Err = model.NewAppError("SqlEmojiStore.GetByName", "store.sql_emoji.get_by_name.app_error", nil, fmt.Sprintf("names=%v, %v", names, err.Error()), http.StatusInternalServerError) + } else { + result.Data = emojis + } + }) +} + func (es SqlEmojiStore) GetList(offset, limit int, sort string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { var emoji []*model.Emoji @@ -151,7 +173,7 @@ func (es SqlEmojiStore) GetList(offset, limit int, sort string) store.StoreChann func (es SqlEmojiStore) Delete(id string, time int64) store.StoreChannel { return store.Do(func(result *store.StoreResult) { if sqlResult, err := es.GetMaster().Exec( - `Update + `UPDATE Emoji SET DeleteAt = :DeleteAt, diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index bc85b260ec..4317d968b0 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -4,7 +4,6 @@ package sqlstore import ( - "bytes" "fmt" "net/http" "regexp" @@ -1144,19 +1143,9 @@ func (s *SqlPostStore) GetPostsCreatedAt(channelId string, time int64) store.Sto func (s *SqlPostStore) GetPostsByIds(postIds []string) store.StoreChannel { return store.Do(func(result *store.StoreResult) { - keys := bytes.Buffer{} - params := make(map[string]interface{}) - for i, postId := range postIds { - if keys.Len() > 0 { - keys.WriteString(",") - } + keys, params := MapStringsToQueryParams(postIds, "Post") - key := "Post" + strconv.Itoa(i) - keys.WriteString(":" + key) - params[key] = postId - } - - query := `SELECT * FROM Posts WHERE Id in (` + keys.String() + `) ORDER BY CreateAt DESC` + query := `SELECT * FROM Posts WHERE Id IN ` + keys + ` ORDER BY CreateAt DESC` var posts []*model.Post _, err := s.GetReplica().Select(&posts, query, params) diff --git a/store/sqlstore/supplier_reactions.go b/store/sqlstore/supplier_reactions.go index 5a9c9302a7..9f3d2bbbbc 100644 --- a/store/sqlstore/supplier_reactions.go +++ b/store/sqlstore/supplier_reactions.go @@ -192,22 +192,18 @@ func deleteReactionAndUpdatePost(transaction *gorp.Transaction, reaction *model. } const ( - // Set HasReactions = true if and only if the post has reactions, update UpdateAt only if HasReactions changes UPDATE_POST_HAS_REACTIONS_ON_DELETE_QUERY = `UPDATE Posts SET - UpdateAt = (CASE - WHEN HasReactions != (SELECT count(0) > 0 FROM Reactions WHERE PostId = :PostId) THEN :UpdateAt - ELSE UpdateAt - END), + UpdateAt = :UpdateAt, HasReactions = (SELECT count(0) > 0 FROM Reactions WHERE PostId = :PostId) WHERE Id = :PostId` ) func updatePostForReactionsOnDelete(transaction *gorp.Transaction, postId string) error { - _, err := transaction.Exec(UPDATE_POST_HAS_REACTIONS_ON_DELETE_QUERY, map[string]interface{}{"PostId": postId, "UpdateAt": model.GetMillis()}) - + updateAt := model.GetMillis() + _, err := transaction.Exec(UPDATE_POST_HAS_REACTIONS_ON_DELETE_QUERY, map[string]interface{}{"PostId": postId, "UpdateAt": updateAt}) return err } @@ -219,7 +215,7 @@ func updatePostForReactionsOnInsert(transaction *gorp.Transaction, postId string HasReactions = True, UpdateAt = :UpdateAt WHERE - Id = :PostId AND HasReactions = False`, + Id = :PostId`, map[string]interface{}{"PostId": postId, "UpdateAt": model.GetMillis()}) return err diff --git a/store/sqlstore/utils.go b/store/sqlstore/utils.go new file mode 100644 index 0000000000..61eac9758b --- /dev/null +++ b/store/sqlstore/utils.go @@ -0,0 +1,28 @@ +// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package sqlstore + +import ( + "bytes" + "fmt" + "strconv" +) + +// Converts a list of strings into a list of query parameters and a named parameter map that can +// be used as part of a SQL query. +func MapStringsToQueryParams(list []string, paramPrefix string) (string, map[string]interface{}) { + keys := bytes.Buffer{} + params := make(map[string]interface{}) + for i, entry := range list { + if keys.Len() > 0 { + keys.WriteString(",") + } + + key := paramPrefix + strconv.Itoa(i) + keys.WriteString(":" + key) + params[key] = entry + } + + return fmt.Sprintf("(%v)", keys.String()), params +} diff --git a/store/sqlstore/utils_test.go b/store/sqlstore/utils_test.go new file mode 100644 index 0000000000..2adf4dc10f --- /dev/null +++ b/store/sqlstore/utils_test.go @@ -0,0 +1,32 @@ +package sqlstore + +import ( + "testing" +) + +func TestMapStringsToQueryParams(t *testing.T) { + t.Run("one item", func(t *testing.T) { + input := []string{"apple"} + + keys, params := MapStringsToQueryParams(input, "Fruit") + + if len(params) != 1 || params["Fruit0"] != "apple" { + t.Fatal("returned incorrect params", params) + } else if keys != "(:Fruit0)" { + t.Fatal("returned incorrect query", keys) + } + }) + + t.Run("multiple items", func(t *testing.T) { + input := []string{"carrot", "tomato", "potato"} + + keys, params := MapStringsToQueryParams(input, "Vegetable") + + if len(params) != 3 || params["Vegetable0"] != "carrot" || + params["Vegetable1"] != "tomato" || params["Vegetable2"] != "potato" { + t.Fatal("returned incorrect params", params) + } else if keys != "(:Vegetable0,:Vegetable1,:Vegetable2)" { + t.Fatal("returned incorrect query", keys) + } + }) +} diff --git a/store/store.go b/store/store.go index eefaa4649b..3ba6b5715b 100644 --- a/store/store.go +++ b/store/store.go @@ -427,6 +427,7 @@ type EmojiStore interface { Save(emoji *model.Emoji) StoreChannel Get(id string, allowFromCache bool) StoreChannel GetByName(name string) StoreChannel + GetMultipleByName(names []string) StoreChannel GetList(offset, limit int, sort string) StoreChannel Delete(id string, time int64) StoreChannel Search(name string, prefixOnly bool, limit int) StoreChannel diff --git a/store/storetest/emoji_store.go b/store/storetest/emoji_store.go index 9e4dbaa6eb..087bdbecbb 100644 --- a/store/storetest/emoji_store.go +++ b/store/storetest/emoji_store.go @@ -17,6 +17,7 @@ func TestEmojiStore(t *testing.T, ss store.Store) { t.Run("EmojiSaveDelete", func(t *testing.T) { testEmojiSaveDelete(t, ss) }) t.Run("EmojiGet", func(t *testing.T) { testEmojiGet(t, ss) }) t.Run("EmojiGetByName", func(t *testing.T) { testEmojiGetByName(t, ss) }) + t.Run("EmojiGetMultipleByName", func(t *testing.T) { testEmojiGetMultipleByName(t, ss) }) t.Run("EmojiGetList", func(t *testing.T) { testEmojiGetList(t, ss) }) t.Run("EmojiSearch", func(t *testing.T) { testEmojiSearch(t, ss) }) } @@ -132,6 +133,64 @@ func testEmojiGetByName(t *testing.T, ss store.Store) { } } +func testEmojiGetMultipleByName(t *testing.T, ss store.Store) { + emojis := []model.Emoji{ + { + CreatorId: model.NewId(), + Name: model.NewId(), + }, + { + CreatorId: model.NewId(), + Name: model.NewId(), + }, + { + CreatorId: model.NewId(), + Name: model.NewId(), + }, + } + + for i, emoji := range emojis { + emojis[i] = *store.Must(ss.Emoji().Save(&emoji)).(*model.Emoji) + } + defer func() { + for _, emoji := range emojis { + store.Must(ss.Emoji().Delete(emoji.Id, time.Now().Unix())) + } + }() + + t.Run("one emoji", func(t *testing.T) { + if result := <-ss.Emoji().GetMultipleByName([]string{emojis[0].Name}); result.Err != nil { + t.Fatal("could not get emoji", result.Err) + } else if received := result.Data.([]*model.Emoji); len(received) != 1 || *received[0] != emojis[0] { + t.Fatal("got incorrect emoji") + } + }) + + t.Run("multiple emojis", func(t *testing.T) { + if result := <-ss.Emoji().GetMultipleByName([]string{emojis[0].Name, emojis[1].Name, emojis[2].Name}); result.Err != nil { + t.Fatal("could not get emojis", result.Err) + } else if received := result.Data.([]*model.Emoji); len(received) != 3 { + t.Fatal("got incorrect emojis") + } + }) + + t.Run("one nonexistent emoji", func(t *testing.T) { + if result := <-ss.Emoji().GetMultipleByName([]string{"ab"}); result.Err != nil { + t.Fatal("could not get emoji", result.Err) + } else if received := result.Data.([]*model.Emoji); len(received) != 0 { + t.Fatal("got incorrect emoji") + } + }) + + t.Run("multiple emojis with nonexistent names", func(t *testing.T) { + if result := <-ss.Emoji().GetMultipleByName([]string{emojis[0].Name, emojis[1].Name, emojis[2].Name, "abcd", "1234"}); result.Err != nil { + t.Fatal("could not get emojis", result.Err) + } else if received := result.Data.([]*model.Emoji); len(received) != 3 { + t.Fatal("got incorrect emojis") + } + }) +} + func testEmojiGetList(t *testing.T, ss store.Store) { emojis := []model.Emoji{ { diff --git a/store/storetest/mocks/EmojiStore.go b/store/storetest/mocks/EmojiStore.go index b1f0a3217f..80b12cfe69 100644 --- a/store/storetest/mocks/EmojiStore.go +++ b/store/storetest/mocks/EmojiStore.go @@ -77,6 +77,22 @@ func (_m *EmojiStore) GetList(offset int, limit int, sort string) store.StoreCha return r0 } +// GetMultipleByName provides a mock function with given fields: names +func (_m *EmojiStore) GetMultipleByName(names []string) store.StoreChannel { + ret := _m.Called(names) + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func([]string) store.StoreChannel); ok { + r0 = rf(names) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} + // Save provides a mock function with given fields: emoji func (_m *EmojiStore) Save(emoji *model.Emoji) store.StoreChannel { ret := _m.Called(emoji) diff --git a/store/storetest/reaction_store.go b/store/storetest/reaction_store.go index 1b51284261..303f0d07e5 100644 --- a/store/storetest/reaction_store.go +++ b/store/storetest/reaction_store.go @@ -61,8 +61,8 @@ func testReactionSave(t *testing.T, ss store.Store) { t.Fatal(result.Err) } - if postList := store.Must(ss.Post().Get(reaction2.PostId)).(*model.PostList); postList.Posts[post.Id].UpdateAt != secondUpdateAt { - t.Fatal("shouldn't mark as updated when HasReactions hasn't changed") + if postList := store.Must(ss.Post().Get(reaction2.PostId)).(*model.PostList); postList.Posts[post.Id].UpdateAt == secondUpdateAt { + t.Fatal("should've marked post as updated even if HasReactions doesn't change") } // different post @@ -123,7 +123,7 @@ func testReactionDelete(t *testing.T, ss store.Store) { if postList := store.Must(ss.Post().Get(post.Id)).(*model.PostList); postList.Posts[post.Id].HasReactions { t.Fatal("should've set HasReactions = false on post") } else if postList.Posts[post.Id].UpdateAt == firstUpdateAt { - t.Fatal("shouldn't mark as updated when HasReactions has changed after deleting reactions") + t.Fatal("should mark post as updated after deleting reactions") } }