diff --git a/api4/reaction.go b/api4/reaction.go index 76a0f0306f..882cb46a8c 100644 --- a/api4/reaction.go +++ b/api4/reaction.go @@ -13,6 +13,7 @@ func (api *API) InitReaction() { api.BaseRoutes.Reactions.Handle("", api.ApiSessionRequired(saveReaction)).Methods("POST") api.BaseRoutes.Post.Handle("/reactions", api.ApiSessionRequired(getReactions)).Methods("GET") api.BaseRoutes.ReactionByNameForPostForUser.Handle("", api.ApiSessionRequired(deleteReaction)).Methods("DELETE") + api.BaseRoutes.Posts.Handle("/ids/reactions", api.ApiSessionRequired(getBulkReactions)).Methods("POST") } func saveReaction(c *Context, w http.ResponseWriter, r *http.Request) { @@ -106,3 +107,20 @@ func deleteReaction(c *Context, w http.ResponseWriter, r *http.Request) { ReturnStatusOK(w) } + +func getBulkReactions(c *Context, w http.ResponseWriter, r *http.Request) { + postIds := model.ArrayFromJson(r.Body) + for _, postId := range postIds { + if !c.App.SessionHasPermissionToChannelByPost(c.App.Session, postId, model.PERMISSION_READ_CHANNEL) { + c.SetPermissionError(model.PERMISSION_READ_CHANNEL) + return + } + } + reactions, err := c.App.GetBulkReactionsForPosts(postIds) + if err != nil { + c.Err = err + return + } + + w.Write([]byte(model.MapPostIdToReactionsToJson(reactions))) +} diff --git a/api4/reaction_test.go b/api4/reaction_test.go index 7a0ade4938..ab3f5a2408 100644 --- a/api4/reaction_test.go +++ b/api4/reaction_test.go @@ -567,3 +567,85 @@ func TestDeleteReaction(t *testing.T) { } }) } + +func TestGetBulkReactions(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + Client := th.Client + userId := th.BasicUser.Id + user2Id := th.BasicUser2.Id + post1 := &model.Post{UserId: userId, ChannelId: th.BasicChannel.Id, Message: "zz" + model.NewId() + "a"} + post2 := &model.Post{UserId: userId, ChannelId: th.BasicChannel.Id, Message: "zz" + model.NewId() + "a"} + post3 := &model.Post{UserId: userId, ChannelId: th.BasicChannel.Id, Message: "zz" + model.NewId() + "a"} + + post4 := &model.Post{UserId: user2Id, ChannelId: th.BasicChannel.Id, Message: "zz" + model.NewId() + "a"} + post5 := &model.Post{UserId: user2Id, ChannelId: th.BasicChannel.Id, Message: "zz" + model.NewId() + "a"} + + post1, _ = Client.CreatePost(post1) + post2, _ = Client.CreatePost(post2) + post3, _ = Client.CreatePost(post3) + post4, _ = Client.CreatePost(post4) + post5, _ = Client.CreatePost(post5) + + expectedPostIdsReactionsMap := make(map[string][]*model.Reaction) + expectedPostIdsReactionsMap[post1.Id] = []*model.Reaction{} + expectedPostIdsReactionsMap[post2.Id] = []*model.Reaction{} + expectedPostIdsReactionsMap[post3.Id] = []*model.Reaction{} + expectedPostIdsReactionsMap[post5.Id] = []*model.Reaction{} + + userReactions := []*model.Reaction{ + { + UserId: userId, + PostId: post1.Id, + EmojiName: "happy", + }, + { + UserId: userId, + PostId: post1.Id, + EmojiName: "sad", + }, + { + UserId: userId, + PostId: post2.Id, + EmojiName: "smile", + }, + { + UserId: user2Id, + PostId: post4.Id, + EmojiName: "smile", + }, + } + + for _, userReaction := range userReactions { + reactions := expectedPostIdsReactionsMap[userReaction.PostId] + if result := <-th.App.Srv.Store.Reaction().Save(userReaction); result.Err != nil { + t.Fatal(result.Err) + } else { + reactions = append(reactions, result.Data.(*model.Reaction)) + + } + expectedPostIdsReactionsMap[userReaction.PostId] = reactions + } + + postIds := []string{post1.Id, post2.Id, post3.Id, post4.Id, post5.Id} + + t.Run("get-reactions", func(t *testing.T) { + postIdsReactionsMap, resp := Client.GetBulkReactions(postIds) + CheckNoError(t, resp) + + assert.ElementsMatch(t, expectedPostIdsReactionsMap[post1.Id], postIdsReactionsMap[post1.Id]) + assert.ElementsMatch(t, expectedPostIdsReactionsMap[post2.Id], postIdsReactionsMap[post2.Id]) + assert.ElementsMatch(t, expectedPostIdsReactionsMap[post3.Id], postIdsReactionsMap[post3.Id]) + assert.ElementsMatch(t, expectedPostIdsReactionsMap[post4.Id], postIdsReactionsMap[post4.Id]) + assert.ElementsMatch(t, expectedPostIdsReactionsMap[post5.Id], postIdsReactionsMap[post5.Id]) + assert.Equal(t, expectedPostIdsReactionsMap, postIdsReactionsMap) + + }) + + t.Run("get-reactions-as-anonymous-user", func(t *testing.T) { + Client.Logout() + + _, resp := Client.GetBulkReactions(postIds) + CheckUnauthorizedStatus(t, resp) + }) +} diff --git a/app/reaction.go b/app/reaction.go index 19dc02788c..5b4e916222 100644 --- a/app/reaction.go +++ b/app/reaction.go @@ -60,6 +60,35 @@ func (a *App) GetReactionsForPost(postId string) ([]*model.Reaction, *model.AppE return result.Data.([]*model.Reaction), nil } +func (a *App) GetBulkReactionsForPosts(postIds []string) (map[string][]*model.Reaction, *model.AppError) { + reactions := make(map[string][]*model.Reaction) + + result := <-a.Srv.Store.Reaction().BulkGetForPosts(postIds) + if result.Err != nil { + return nil, result.Err + } + + allReactions := result.Data.([]*model.Reaction) + for _, reaction := range allReactions { + reactionsForPost := reactions[reaction.PostId] + reactionsForPost = append(reactionsForPost, reaction) + + reactions[reaction.PostId] = reactionsForPost + } + + reactions = populateEmptyReactions(postIds, reactions) + return reactions, nil +} + +func populateEmptyReactions(postIds []string, reactions map[string][]*model.Reaction) map[string][]*model.Reaction { + for _, postId := range postIds { + if _, present := reactions[postId]; !present { + reactions[postId] = []*model.Reaction{} + } + } + return reactions +} + func (a *App) DeleteReactionForPost(reaction *model.Reaction) *model.AppError { post, err := a.GetSinglePost(reaction.PostId) if err != nil { diff --git a/i18n/en.json b/i18n/en.json index 1fea346163..7275910878 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -5842,6 +5842,10 @@ "id": "store.sql_reaction.get_for_post.app_error", "translation": "Unable to get reactions for post" }, + { + "id": "store.sql_reaction.bulk_get_for_post_ids.app_error", + "translation": "Unable to get reactions for post" + }, { "id": "store.sql_reaction.permanent_delete_batch.app_error", "translation": "We encountered an error permanently deleting the batch of reactions" diff --git a/model/client4.go b/model/client4.go index da8c14ccd1..68a5452410 100644 --- a/model/client4.go +++ b/model/client4.go @@ -3592,6 +3592,16 @@ func (c *Client4) DeleteReaction(reaction *Reaction) (bool, *Response) { return CheckStatusOK(r), BuildResponse(r) } +// FetchBulkReactions returns a map of postIds and corresponding reactions +func (c *Client4) GetBulkReactions(postIds []string) (map[string][]*Reaction, *Response) { + r, err := c.DoApiPost(c.GetPostsRoute()+"/ids/reactions", ArrayToJson(postIds)) + if err != nil { + return nil, BuildErrorResponse(r, err) + } + defer closeBody(r) + return MapPostIdToReactionsFromJson(r.Body), BuildResponse(r) +} + // Timezone Section // GetSupportedTimezone returns a page of supported timezones on the system. diff --git a/model/reaction.go b/model/reaction.go index c1b9c499a8..550918a33a 100644 --- a/model/reaction.go +++ b/model/reaction.go @@ -37,6 +37,22 @@ func ReactionsToJson(o []*Reaction) string { return string(b) } +func MapPostIdToReactionsToJson(o map[string][]*Reaction) string { + b, _ := json.Marshal(o) + return string(b) +} + +func MapPostIdToReactionsFromJson(data io.Reader) map[string][]*Reaction { + decoder := json.NewDecoder(data) + + var objmap map[string][]*Reaction + if err := decoder.Decode(&objmap); err != nil { + return make(map[string][]*Reaction) + } else { + return objmap + } +} + func ReactionsFromJson(data io.Reader) []*Reaction { var o []*Reaction diff --git a/store/layered_store.go b/store/layered_store.go index 639e60001f..4ea81f6ede 100644 --- a/store/layered_store.go +++ b/store/layered_store.go @@ -236,6 +236,12 @@ func (s *LayeredReactionStore) GetForPost(postId string, allowFromCache bool) St }) } +func (s *LayeredReactionStore) BulkGetForPosts(postIds []string) StoreChannel { + return s.RunQuery(func(supplier LayeredStoreSupplier) *LayeredStoreSupplierResult { + return supplier.ReactionsBulkGetForPosts(s.TmpContext, postIds) + }) +} + func (s *LayeredReactionStore) DeleteAllWithEmojiName(emojiName string) StoreChannel { return s.RunQuery(func(supplier LayeredStoreSupplier) *LayeredStoreSupplierResult { return supplier.ReactionDeleteAllWithEmojiName(s.TmpContext, emojiName) diff --git a/store/layered_store_supplier.go b/store/layered_store_supplier.go index 9718590076..45ec00068f 100644 --- a/store/layered_store_supplier.go +++ b/store/layered_store_supplier.go @@ -29,6 +29,7 @@ type LayeredStoreSupplier interface { ReactionGetForPost(ctx context.Context, postId string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult ReactionDeleteAllWithEmojiName(ctx context.Context, emojiName string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult ReactionPermanentDeleteBatch(ctx context.Context, endTime int64, limit int64, hints ...LayeredStoreHint) *LayeredStoreSupplierResult + ReactionsBulkGetForPosts(ctx context.Context, postIds []string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult // Roles RoleSave(ctx context.Context, role *model.Role, hints ...LayeredStoreHint) *LayeredStoreSupplierResult diff --git a/store/local_cache_supplier_reactions.go b/store/local_cache_supplier_reactions.go index dd588e7c43..260d20e9af 100644 --- a/store/local_cache_supplier_reactions.go +++ b/store/local_cache_supplier_reactions.go @@ -51,3 +51,8 @@ func (s *LocalCacheSupplier) ReactionPermanentDeleteBatch(ctx context.Context, e // expire from the cache in due course. return s.Next().ReactionPermanentDeleteBatch(ctx, endTime, limit) } + +func (s *LocalCacheSupplier) ReactionsBulkGetForPosts(ctx context.Context, postIds []string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + // Ignoring this. + return s.Next().ReactionsBulkGetForPosts(ctx, postIds, hints...) +} diff --git a/store/redis_supplier_reactions.go b/store/redis_supplier_reactions.go index ec9a4b4e03..eccd87ed8c 100644 --- a/store/redis_supplier_reactions.go +++ b/store/redis_supplier_reactions.go @@ -54,3 +54,8 @@ func (s *RedisSupplier) ReactionPermanentDeleteBatch(ctx context.Context, endTim // Ignoring this. It's probably OK to have the emoji slowly expire from Redis. return s.Next().ReactionPermanentDeleteBatch(ctx, endTime, limit, hints...) } + +func (s *RedisSupplier) ReactionsBulkGetForPosts(ctx context.Context, postIds []string, hints ...LayeredStoreHint) *LayeredStoreSupplierResult { + // Ignoring this. + return s.Next().ReactionsBulkGetForPosts(ctx, postIds, hints...) +} diff --git a/store/sqlstore/supplier_reactions.go b/store/sqlstore/supplier_reactions.go index 9f3d2bbbbc..9ec299adba 100644 --- a/store/sqlstore/supplier_reactions.go +++ b/store/sqlstore/supplier_reactions.go @@ -103,6 +103,28 @@ func (s *SqlSupplier) ReactionGetForPost(ctx context.Context, postId string, hin return result } +func (s *SqlSupplier) ReactionsBulkGetForPosts(ctx context.Context, postIds []string, hints ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult { + result := store.NewSupplierResult() + + keys, params := MapStringsToQueryParams(postIds, "postId") + var reactions []*model.Reaction + + if _, err := s.GetReplica().Select(&reactions, `SELECT + * + FROM + Reactions + WHERE + PostId IN `+keys+` + ORDER BY + CreateAt`, params); err != nil { + result.Err = model.NewAppError("SqlReactionStore.GetForPost", "store.sql_reaction.bulk_get_for_post_ids.app_error", nil, "", http.StatusInternalServerError) + } else { + result.Data = reactions + } + + return result +} + func (s *SqlSupplier) ReactionDeleteAllWithEmojiName(ctx context.Context, emojiName string, hints ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult { result := store.NewSupplierResult() diff --git a/store/store.go b/store/store.go index 708352bd48..7eb080cd6b 100644 --- a/store/store.go +++ b/store/store.go @@ -467,6 +467,7 @@ type ReactionStore interface { GetForPost(postId string, allowFromCache bool) StoreChannel DeleteAllWithEmojiName(emojiName string) StoreChannel PermanentDeleteBatch(endTime int64, limit int64) StoreChannel + BulkGetForPosts(postIds []string) StoreChannel } type JobStore interface { diff --git a/store/storetest/mocks/LayeredStoreDatabaseLayer.go b/store/storetest/mocks/LayeredStoreDatabaseLayer.go index 0531d7c377..eb0881dd85 100644 --- a/store/storetest/mocks/LayeredStoreDatabaseLayer.go +++ b/store/storetest/mocks/LayeredStoreDatabaseLayer.go @@ -421,6 +421,29 @@ func (_m *LayeredStoreDatabaseLayer) ReactionSave(ctx context.Context, reaction return r0 } +// ReactionsBulkGetForPosts provides a mock function with given fields: ctx, postIds, hints +func (_m *LayeredStoreDatabaseLayer) ReactionsBulkGetForPosts(ctx context.Context, postIds []string, hints ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult { + _va := make([]interface{}, len(hints)) + for _i := range hints { + _va[_i] = hints[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, postIds) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *store.LayeredStoreSupplierResult + if rf, ok := ret.Get(0).(func(context.Context, []string, ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult); ok { + r0 = rf(ctx, postIds, hints...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.LayeredStoreSupplierResult) + } + } + + return r0 +} + // Role provides a mock function with given fields: func (_m *LayeredStoreDatabaseLayer) Role() store.RoleStore { ret := _m.Called() diff --git a/store/storetest/mocks/LayeredStoreSupplier.go b/store/storetest/mocks/LayeredStoreSupplier.go index ddd0abf589..4b3da6efc7 100644 --- a/store/storetest/mocks/LayeredStoreSupplier.go +++ b/store/storetest/mocks/LayeredStoreSupplier.go @@ -145,6 +145,29 @@ func (_m *LayeredStoreSupplier) ReactionSave(ctx context.Context, reaction *mode return r0 } +// ReactionsBulkGetForPosts provides a mock function with given fields: ctx, postIds, hints +func (_m *LayeredStoreSupplier) ReactionsBulkGetForPosts(ctx context.Context, postIds []string, hints ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult { + _va := make([]interface{}, len(hints)) + for _i := range hints { + _va[_i] = hints[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, postIds) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *store.LayeredStoreSupplierResult + if rf, ok := ret.Get(0).(func(context.Context, []string, ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult); ok { + r0 = rf(ctx, postIds, hints...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*store.LayeredStoreSupplierResult) + } + } + + return r0 +} + // RoleDelete provides a mock function with given fields: ctx, roldId, hints func (_m *LayeredStoreSupplier) RoleDelete(ctx context.Context, roldId string, hints ...store.LayeredStoreHint) *store.LayeredStoreSupplierResult { _va := make([]interface{}, len(hints)) diff --git a/store/storetest/mocks/ReactionStore.go b/store/storetest/mocks/ReactionStore.go index b3e81a83b7..0f72a98b37 100644 --- a/store/storetest/mocks/ReactionStore.go +++ b/store/storetest/mocks/ReactionStore.go @@ -13,6 +13,22 @@ type ReactionStore struct { mock.Mock } +// BulkGetForPosts provides a mock function with given fields: postIds +func (_m *ReactionStore) BulkGetForPosts(postIds []string) store.StoreChannel { + ret := _m.Called(postIds) + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func([]string) store.StoreChannel); ok { + r0 = rf(postIds) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} + // Delete provides a mock function with given fields: reaction func (_m *ReactionStore) Delete(reaction *model.Reaction) store.StoreChannel { ret := _m.Called(reaction) diff --git a/store/storetest/mocks/SqlStore.go b/store/storetest/mocks/SqlStore.go index 6f76c8a030..4ccc53dd79 100644 --- a/store/storetest/mocks/SqlStore.go +++ b/store/storetest/mocks/SqlStore.go @@ -14,6 +14,20 @@ type SqlStore struct { mock.Mock } +// AlterColumnDefaultIfExists provides a mock function with given fields: tableName, columnName, mySqlColDefault, postgresColDefault +func (_m *SqlStore) AlterColumnDefaultIfExists(tableName string, columnName string, mySqlColDefault *string, postgresColDefault *string) bool { + ret := _m.Called(tableName, columnName, mySqlColDefault, postgresColDefault) + + var r0 bool + if rf, ok := ret.Get(0).(func(string, string, *string, *string) bool); ok { + r0 = rf(tableName, columnName, mySqlColDefault, postgresColDefault) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // AlterColumnTypeIfExists provides a mock function with given fields: tableName, columnName, mySqlColType, postgresColType func (_m *SqlStore) AlterColumnTypeIfExists(tableName string, columnName string, mySqlColType string, postgresColType string) bool { ret := _m.Called(tableName, columnName, mySqlColType, postgresColType) diff --git a/store/storetest/reaction_store.go b/store/storetest/reaction_store.go index 303f0d07e5..bba82bee63 100644 --- a/store/storetest/reaction_store.go +++ b/store/storetest/reaction_store.go @@ -16,6 +16,7 @@ func TestReactionStore(t *testing.T, ss store.Store) { t.Run("ReactionGetForPost", func(t *testing.T) { testReactionGetForPost(t, ss) }) t.Run("ReactionDeleteAllWithEmojiName", func(t *testing.T) { testReactionDeleteAllWithEmojiName(t, ss) }) t.Run("PermanentDeleteBatch", func(t *testing.T) { testReactionStorePermanentDeleteBatch(t, ss) }) + t.Run("ReactionBulkGetForPosts", func(t *testing.T) { testReactionBulkGetForPosts(t, ss) }) } func testReactionSave(t *testing.T, ss store.Store) { @@ -348,3 +349,69 @@ func testReactionStorePermanentDeleteBatch(t *testing.T, ss store.Store) { t.Fatalf("expected 1 reaction. Got: %v", len(returned)) } } + +func testReactionBulkGetForPosts(t *testing.T, ss store.Store) { + postId := model.NewId() + post2Id := model.NewId() + post3Id := model.NewId() + post4Id := model.NewId() + + userId := model.NewId() + + reactions := []*model.Reaction{ + { + UserId: userId, + PostId: postId, + EmojiName: "smile", + }, + { + UserId: model.NewId(), + PostId: post2Id, + EmojiName: "smile", + }, + { + UserId: userId, + PostId: post3Id, + EmojiName: "sad", + }, + { + UserId: userId, + PostId: postId, + EmojiName: "angry", + }, + { + UserId: userId, + PostId: post2Id, + EmojiName: "angry", + }, + { + UserId: userId, + PostId: post4Id, + EmojiName: "angry", + }, + } + + for _, reaction := range reactions { + store.Must(ss.Reaction().Save(reaction)) + } + + postIds := []string{postId, post2Id, post3Id} + if result := <-ss.Reaction().BulkGetForPosts(postIds); result.Err != nil { + t.Fatal(result.Err) + } else if returned := result.Data.([]*model.Reaction); len(returned) != 5 { + t.Fatal("should've returned 5 reactions") + } else { + post4IdFound := false + for _, reaction := range returned { + if reaction.PostId == post4Id { + post4IdFound = true + break + } + } + + if post4IdFound { + t.Fatal("Wrong reaction returned") + } + } + +}