diff --git a/api4/commands_test.go b/api4/commands_test.go index 9167e4e60c..0e2ef926d4 100644 --- a/api4/commands_test.go +++ b/api4/commands_test.go @@ -30,7 +30,7 @@ func TestEchoCommand(t *testing.T) { time.Sleep(100 * time.Millisecond) - p1 := Client.Must(Client.GetPostsForChannel(channel1.Id, 0, 2, "")).(*model.PostList) + p1 := Client.Must(Client.GetPostsForChannel(channel1.Id, 0, 2, "", false)).(*model.PostList) require.Len(t, p1.Order, 2, "Echo command failed to send") } @@ -302,7 +302,7 @@ func TestMeCommand(t *testing.T) { time.Sleep(100 * time.Millisecond) - p1 := Client.Must(Client.GetPostsForChannel(channel.Id, 0, 2, "")).(*model.PostList) + p1 := Client.Must(Client.GetPostsForChannel(channel.Id, 0, 2, "", false)).(*model.PostList) require.Len(t, p1.Order, 2, "Command failed to send") pt := p1.Posts[p1.Order[0]].Type @@ -391,7 +391,7 @@ func TestShrugCommand(t *testing.T) { time.Sleep(100 * time.Millisecond) - p1 := Client.Must(Client.GetPostsForChannel(channel.Id, 0, 2, "")).(*model.PostList) + p1 := Client.Must(Client.GetPostsForChannel(channel.Id, 0, 2, "", false)).(*model.PostList) require.Len(t, p1.Order, 2, "Command failed to send") require.Equal(t, `¯\\\_(ツ)\_/¯`, p1.Posts[p1.Order[0]].Message, "invalid shrug response") } diff --git a/api4/post.go b/api4/post.go index 994fc4c7ef..6996cdbdde 100644 --- a/api4/post.go +++ b/api4/post.go @@ -158,7 +158,8 @@ func getPostsForChannel(c *Context, w http.ResponseWriter, r *http.Request) { } } skipFetchThreads := r.URL.Query().Get("skipFetchThreads") == "true" - + collapsedThreads := r.URL.Query().Get("collapsedThreads") == "true" + collapsedThreadsExtended := r.URL.Query().Get("collapsedThreadsExtended") == "true" channelId := c.Params.ChannelId page := c.Params.Page perPage := c.Params.PerPage @@ -173,31 +174,31 @@ func getPostsForChannel(c *Context, w http.ResponseWriter, r *http.Request) { etag := "" if since > 0 { - list, err = c.App.GetPostsSince(model.GetPostsSinceOptions{ChannelId: channelId, Time: since, SkipFetchThreads: skipFetchThreads}) + list, err = c.App.GetPostsSince(model.GetPostsSinceOptions{ChannelId: channelId, Time: since, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads, CollapsedThreadsExtended: collapsedThreadsExtended}) } else if len(afterPost) > 0 { - etag = c.App.GetPostsEtag(channelId) + etag = c.App.GetPostsEtag(channelId, collapsedThreads) if c.HandleEtag(etag, "Get Posts After", w, r) { return } - list, err = c.App.GetPostsAfterPost(model.GetPostsOptions{ChannelId: channelId, PostId: afterPost, Page: page, PerPage: perPage, SkipFetchThreads: skipFetchThreads}) + list, err = c.App.GetPostsAfterPost(model.GetPostsOptions{ChannelId: channelId, PostId: afterPost, Page: page, PerPage: perPage, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads}) } else if len(beforePost) > 0 { - etag = c.App.GetPostsEtag(channelId) + etag = c.App.GetPostsEtag(channelId, collapsedThreads) if c.HandleEtag(etag, "Get Posts Before", w, r) { return } - list, err = c.App.GetPostsBeforePost(model.GetPostsOptions{ChannelId: channelId, PostId: beforePost, Page: page, PerPage: perPage, SkipFetchThreads: skipFetchThreads}) + list, err = c.App.GetPostsBeforePost(model.GetPostsOptions{ChannelId: channelId, PostId: beforePost, Page: page, PerPage: perPage, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads, CollapsedThreadsExtended: collapsedThreadsExtended}) } else { - etag = c.App.GetPostsEtag(channelId) + etag = c.App.GetPostsEtag(channelId, collapsedThreads) if c.HandleEtag(etag, "Get Posts", w, r) { return } - list, err = c.App.GetPostsPage(model.GetPostsOptions{ChannelId: channelId, Page: page, PerPage: perPage, SkipFetchThreads: skipFetchThreads}) + list, err = c.App.GetPostsPage(model.GetPostsOptions{ChannelId: channelId, Page: page, PerPage: perPage, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads, CollapsedThreadsExtended: collapsedThreadsExtended}) } if err != nil { @@ -239,7 +240,10 @@ func getPostsForChannelAroundLastUnread(c *Context, w http.ResponseWriter, r *ht } skipFetchThreads := r.URL.Query().Get("skipFetchThreads") == "true" - postList, err := c.App.GetPostsForChannelAroundLastUnread(channelId, userId, c.Params.LimitBefore, c.Params.LimitAfter, skipFetchThreads) + collapsedThreads := r.URL.Query().Get("collapsedThreads") == "true" + collapsedThreadsExtended := r.URL.Query().Get("collapsedThreadsExtended") == "true" + + postList, err := c.App.GetPostsForChannelAroundLastUnread(channelId, userId, c.Params.LimitBefore, c.Params.LimitAfter, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { c.Err = err return @@ -247,13 +251,13 @@ func getPostsForChannelAroundLastUnread(c *Context, w http.ResponseWriter, r *ht etag := "" if len(postList.Order) == 0 { - etag = c.App.GetPostsEtag(channelId) + etag = c.App.GetPostsEtag(channelId, collapsedThreads) if c.HandleEtag(etag, "Get Posts", w, r) { return } - postList, err = c.App.GetPostsPage(model.GetPostsOptions{ChannelId: channelId, Page: app.PageDefault, PerPage: c.Params.LimitBefore, SkipFetchThreads: skipFetchThreads}) + postList, err = c.App.GetPostsPage(model.GetPostsOptions{ChannelId: channelId, Page: app.PageDefault, PerPage: c.Params.LimitBefore, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads, CollapsedThreadsExtended: collapsedThreadsExtended}) if err != nil { c.Err = err return @@ -412,7 +416,9 @@ func getPostThread(c *Context, w http.ResponseWriter, r *http.Request) { return } skipFetchThreads := r.URL.Query().Get("skipFetchThreads") == "true" - list, err := c.App.GetPostThread(c.Params.PostId, skipFetchThreads) + collapsedThreads := r.URL.Query().Get("collapsedThreads") == "true" + collapsedThreadsExtended := r.URL.Query().Get("collapsedThreadsExtended") == "true" + list, err := c.App.GetPostThread(c.Params.PostId, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { c.Err = err return diff --git a/api4/post_test.go b/api4/post_test.go index 18547494c7..1e6e4c5cd4 100644 --- a/api4/post_test.go +++ b/api4/post_test.go @@ -389,7 +389,7 @@ func testCreatePostWithOutgoingHook( if commentPostType { time.Sleep(time.Millisecond * 100) - postList, resp := th.SystemAdminClient.GetPostThread(post.Id, "") + postList, resp := th.SystemAdminClient.GetPostThread(post.Id, "", false) CheckNoError(t, resp) require.Equal(t, post.Id, postList.Order[0], "wrong order") @@ -1052,17 +1052,17 @@ func TestGetPostsForChannel(t *testing.T) { post4 := th.CreatePost() th.TestForAllClients(t, func(t *testing.T, c *model.Client4) { - posts, resp := c.GetPostsForChannel(th.BasicChannel.Id, 0, 60, "") + posts, resp := c.GetPostsForChannel(th.BasicChannel.Id, 0, 60, "", false) CheckNoError(t, resp) require.Equal(t, post4.Id, posts.Order[0], "wrong order") require.Equal(t, post3.Id, posts.Order[1], "wrong order") require.Equal(t, post2.Id, posts.Order[2], "wrong order") require.Equal(t, post1.Id, posts.Order[3], "wrong order") - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 3, resp.Etag) + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 3, resp.Etag, false) CheckEtag(t, posts, resp) - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 3, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "wrong number returned") @@ -1071,11 +1071,11 @@ func TestGetPostsForChannel(t *testing.T) { _, ok = posts.Posts[post1.Id] require.True(t, ok, "missing root post") - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 1, 1, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 1, 1, "", false) CheckNoError(t, resp) require.Equal(t, post3.Id, posts.Order[0], "wrong order") - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 10000, 10000, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 10000, 10000, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should be no posts") }) @@ -1083,7 +1083,7 @@ func TestGetPostsForChannel(t *testing.T) { post5 := th.CreatePost() th.TestForAllClients(t, func(t *testing.T, c *model.Client4) { - posts, resp := c.GetPostsSince(th.BasicChannel.Id, since) + posts, resp := c.GetPostsSince(th.BasicChannel.Id, since, false) CheckNoError(t, resp) require.Len(t, posts.Posts, 2, "should return 2 posts") @@ -1105,18 +1105,18 @@ func TestGetPostsForChannel(t *testing.T) { require.True(t, f, "missing post") } - _, resp = c.GetPostsForChannel("", 0, 60, "") + _, resp = c.GetPostsForChannel("", 0, 60, "", false) CheckBadRequestStatus(t, resp) - _, resp = c.GetPostsForChannel("junk", 0, 60, "") + _, resp = c.GetPostsForChannel("junk", 0, 60, "", false) CheckBadRequestStatus(t, resp) }) - _, resp := Client.GetPostsForChannel(model.NewId(), 0, 60, "") + _, resp := Client.GetPostsForChannel(model.NewId(), 0, 60, "", false) CheckForbiddenStatus(t, resp) Client.Logout() - _, resp = Client.GetPostsForChannel(model.NewId(), 0, 60, "") + _, resp = Client.GetPostsForChannel(model.NewId(), 0, 60, "", false) CheckUnauthorizedStatus(t, resp) // more tests for next_post_id, prev_post_id, and order @@ -1131,11 +1131,11 @@ func TestGetPostsForChannel(t *testing.T) { var posts *model.PostList th.TestForAllClients(t, func(t *testing.T, c *model.Client4) { // get the system post IDs posted before the created posts above - posts, resp = c.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 2, "") + posts, resp = c.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 2, "", false) systemPostId1 := posts.Order[1] // similar to '/posts' - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 60, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 60, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 12, "expected 12 posts") require.Equal(t, post10.Id, posts.Order[0], "posts not in order") @@ -1144,7 +1144,7 @@ func TestGetPostsForChannel(t *testing.T) { require.Equal(t, "", posts.PrevPostId, "should return an empty PrevPostId") // similar to '/posts?per_page=3' - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 3, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post10.Id, posts.Order[0], "posts not in order") @@ -1153,7 +1153,7 @@ func TestGetPostsForChannel(t *testing.T) { require.Equal(t, post7.Id, posts.PrevPostId, "should return post7.Id as PrevPostId") // similar to '/posts?per_page=3&page=1' - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 1, 3, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 1, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post7.Id, posts.Order[0], "posts not in order") @@ -1162,7 +1162,7 @@ func TestGetPostsForChannel(t *testing.T) { require.Equal(t, post4.Id, posts.PrevPostId, "should return post4.Id as PrevPostId") // similar to '/posts?per_page=3&page=2' - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 2, 3, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 2, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post4.Id, posts.Order[0], "posts not in order") @@ -1171,7 +1171,7 @@ func TestGetPostsForChannel(t *testing.T) { require.Equal(t, post1.Id, posts.PrevPostId, "should return post1.Id as PrevPostId") // similar to '/posts?per_page=3&page=3' - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 3, 3, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 3, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post1.Id, posts.Order[0], "posts not in order") @@ -1180,7 +1180,7 @@ func TestGetPostsForChannel(t *testing.T) { require.Equal(t, "", posts.PrevPostId, "should return an empty PrevPostId") // similar to '/posts?per_page=3&page=4' - posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 4, 3, "") + posts, resp = c.GetPostsForChannel(th.BasicChannel.Id, 4, 3, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 post") require.Equal(t, "", posts.NextPostId, "should return an empty NextPostId") @@ -1390,7 +1390,7 @@ func TestGetPostsBefore(t *testing.T) { post4 := th.CreatePost() post5 := th.CreatePost() - posts, resp := Client.GetPostsBefore(th.BasicChannel.Id, post3.Id, 0, 100, "") + posts, resp := Client.GetPostsBefore(th.BasicChannel.Id, post3.Id, 0, 100, "", false) CheckNoError(t, resp) found := make([]bool, 2) @@ -1412,17 +1412,17 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, post3.Id, posts.NextPostId, "should match NextPostId") require.Equal(t, "", posts.PrevPostId, "should match empty PrevPostId") - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post4.Id, 1, 1, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post4.Id, 1, 1, "", false) CheckNoError(t, resp) require.Len(t, posts.Posts, 1, "too many posts returned") require.Equal(t, post2.Id, posts.Order[0], "should match returned post") require.Equal(t, post3.Id, posts.NextPostId, "should match NextPostId") require.Equal(t, post1.Id, posts.PrevPostId, "should match PrevPostId") - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, "junk", 1, 1, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, "junk", 1, 1, "", false) CheckBadRequestStatus(t, resp) - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post5.Id, 0, 3, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post5.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Posts, 3, "should match length of posts returned") require.Equal(t, post4.Id, posts.Order[0], "should match returned post") @@ -1431,12 +1431,12 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, post1.Id, posts.PrevPostId, "should match PrevPostId") // get the system post IDs posted before the created posts above - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 2, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 2, "", false) CheckNoError(t, resp) systemPostId2 := posts.Order[0] systemPostId1 := posts.Order[1] - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post5.Id, 1, 3, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post5.Id, 1, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Posts, 3, "should match length of posts returned") require.Equal(t, post1.Id, posts.Order[0], "should match returned post") @@ -1454,7 +1454,7 @@ func TestGetPostsBefore(t *testing.T) { th.CreatePost() // post10 // similar to '/posts?before=post9' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 0, 60, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 0, 60, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 10, "expected 10 posts") require.Equal(t, post8.Id, posts.Order[0], "posts not in order") @@ -1463,7 +1463,7 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, "", posts.PrevPostId, "should return an empty PrevPostId") // similar to '/posts?before=post9&per_page=3' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 0, 3, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post8.Id, posts.Order[0], "posts not in order") @@ -1472,7 +1472,7 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, post5.Id, posts.PrevPostId, "should return post5.Id as PrevPostId") // similar to '/posts?before=post9&per_page=3&page=1' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 1, 3, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 1, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post5.Id, posts.Order[0], "posts not in order") @@ -1481,7 +1481,7 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, post2.Id, posts.PrevPostId, "should return post2.Id as PrevPostId") // similar to '/posts?before=post9&per_page=3&page=2' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 2, 3, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post9.Id, 2, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post2.Id, posts.Order[0], "posts not in order") @@ -1490,7 +1490,7 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, systemPostId1, posts.PrevPostId, "should return systemPostId1 as PrevPostId") // similar to '/posts?before=post1&per_page=3' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 3, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 2, "expected 2 posts") require.Equal(t, systemPostId2, posts.Order[0], "posts not in order") @@ -1499,14 +1499,14 @@ func TestGetPostsBefore(t *testing.T) { require.Equal(t, "", posts.PrevPostId, "should return an empty PrevPostId") // similar to '/posts?before=systemPostId1' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, systemPostId1, 0, 60, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, systemPostId1, 0, 60, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 post") require.Equal(t, systemPostId1, posts.NextPostId, "should return systemPostId1 as NextPostId") require.Equal(t, "", posts.PrevPostId, "should return an empty PrevPostId") // similar to '/posts?before=systemPostId1&per_page=60&page=1' - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, systemPostId1, 1, 60, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, systemPostId1, 1, 60, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 posts") require.Equal(t, "", posts.NextPostId, "should return an empty NextPostId") @@ -1514,7 +1514,7 @@ func TestGetPostsBefore(t *testing.T) { // similar to '/posts?before=non-existent-post' nonExistentPostId := model.NewId() - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, nonExistentPostId, 0, 60, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, nonExistentPostId, 0, 60, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 post") require.Equal(t, nonExistentPostId, posts.NextPostId, "should return nonExistentPostId as NextPostId") @@ -1532,7 +1532,7 @@ func TestGetPostsAfter(t *testing.T) { post4 := th.CreatePost() post5 := th.CreatePost() - posts, resp := Client.GetPostsAfter(th.BasicChannel.Id, post3.Id, 0, 100, "") + posts, resp := Client.GetPostsAfter(th.BasicChannel.Id, post3.Id, 0, 100, "", false) CheckNoError(t, resp) found := make([]bool, 2) @@ -1552,17 +1552,17 @@ func TestGetPostsAfter(t *testing.T) { require.Equal(t, "", posts.NextPostId, "should match empty NextPostId") require.Equal(t, post3.Id, posts.PrevPostId, "should match PrevPostId") - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 1, 1, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 1, 1, "", false) CheckNoError(t, resp) require.Len(t, posts.Posts, 1, "too many posts returned") require.Equal(t, post4.Id, posts.Order[0], "should match returned post") require.Equal(t, post5.Id, posts.NextPostId, "should match NextPostId") require.Equal(t, post3.Id, posts.PrevPostId, "should match PrevPostId") - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, "junk", 1, 1, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, "junk", 1, 1, "", false) CheckBadRequestStatus(t, resp) - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post1.Id, 0, 3, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post1.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Posts, 3, "should match length of posts returned") require.Equal(t, post4.Id, posts.Order[0], "should match returned post") @@ -1570,7 +1570,7 @@ func TestGetPostsAfter(t *testing.T) { require.Equal(t, post5.Id, posts.NextPostId, "should match NextPostId") require.Equal(t, post1.Id, posts.PrevPostId, "should match PrevPostId") - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post1.Id, 1, 3, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post1.Id, 1, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Posts, 1, "should match length of posts returned") require.Equal(t, post5.Id, posts.Order[0], "should match returned post") @@ -1586,7 +1586,7 @@ func TestGetPostsAfter(t *testing.T) { post10 := th.CreatePost() // similar to '/posts?after=post2' - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 0, 60, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 0, 60, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 8, "expected 8 posts") require.Equal(t, post10.Id, posts.Order[0], "should match order") @@ -1595,7 +1595,7 @@ func TestGetPostsAfter(t *testing.T) { require.Equal(t, post2.Id, posts.PrevPostId, "should return post2.Id as PrevPostId") // similar to '/posts?after=post2&per_page=3' - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 0, 3, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 0, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post5.Id, posts.Order[0], "should match order") @@ -1604,7 +1604,7 @@ func TestGetPostsAfter(t *testing.T) { require.Equal(t, post2.Id, posts.PrevPostId, "should return post2.Id as PrevPostId") // similar to '/posts?after=post2&per_page=3&page=1' - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 1, 3, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 1, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 3, "expected 3 posts") require.Equal(t, post8.Id, posts.Order[0], "should match order") @@ -1613,7 +1613,7 @@ func TestGetPostsAfter(t *testing.T) { require.Equal(t, post5.Id, posts.PrevPostId, "should return post5.Id as PrevPostId") // similar to '/posts?after=post2&per_page=3&page=2' - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 2, 3, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post2.Id, 2, 3, "", false) CheckNoError(t, resp) require.Len(t, posts.Order, 2, "expected 2 posts") require.Equal(t, post10.Id, posts.Order[0], "should match order") @@ -1622,14 +1622,14 @@ func TestGetPostsAfter(t *testing.T) { require.Equal(t, post8.Id, posts.PrevPostId, "should return post8.Id as PrevPostId") // similar to '/posts?after=post10' - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post10.Id, 0, 60, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post10.Id, 0, 60, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 post") require.Equal(t, "", posts.NextPostId, "should return an empty NextPostId") require.Equal(t, post10.Id, posts.PrevPostId, "should return post10.Id as PrevPostId") // similar to '/posts?after=post10&page=1' - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post10.Id, 1, 60, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, post10.Id, 1, 60, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 post") require.Equal(t, "", posts.NextPostId, "should return an empty NextPostId") @@ -1637,7 +1637,7 @@ func TestGetPostsAfter(t *testing.T) { // similar to '/posts?after=non-existent-post' nonExistentPostId := model.NewId() - posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, nonExistentPostId, 0, 60, "") + posts, resp = Client.GetPostsAfter(th.BasicChannel.Id, nonExistentPostId, 0, 60, "", false) CheckNoError(t, resp) require.Empty(t, posts.Order, "should return 0 post") require.Equal(t, "", posts.NextPostId, "should return an empty NextPostId") @@ -1720,13 +1720,13 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { } // Setting limit_after to zero should fail with a 400 BadRequest. - posts, resp := Client.GetPostsAroundLastUnread(userId, channelId, 20, 0) + posts, resp := Client.GetPostsAroundLastUnread(userId, channelId, 20, 0, false) require.Error(t, resp.Error) require.Equal(t, "api.context.invalid_url_param.app_error", resp.Error.Id) require.Equal(t, http.StatusBadRequest, resp.StatusCode) // All returned posts are all read by the user, since it's created by the user itself. - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 20, 20) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 20, 20, false) CheckNoError(t, resp) require.Len(t, posts.Order, 12, "Should return 12 posts only since there's no unread post") @@ -1739,13 +1739,13 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { require.Nil(t, err) th.App.Srv().Store.Post().InvalidateLastPostTimeCache(channelId) - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 20, 20) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 20, 20, false) CheckNoError(t, resp) require.Len(t, posts.Order, 12, "Should return 12 posts only since there's no unread post") // get the first system post generated before the created posts above - posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 2, "") + posts, resp = Client.GetPostsBefore(th.BasicChannel.Id, post1.Id, 0, 2, "", false) CheckNoError(t, resp) systemPost0 := posts.Posts[posts.Order[0]] postIdNames[systemPost0.Id] = "system post 0" @@ -1760,7 +1760,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { require.Nil(t, err) th.App.Srv().Store.Post().InvalidateLastPostTimeCache(channelId) - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3, false) CheckNoError(t, resp) assertPostList(t, &model.PostList{ @@ -1784,7 +1784,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { require.Nil(t, err) th.App.Srv().Store.Post().InvalidateLastPostTimeCache(channelId) - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3, false) CheckNoError(t, resp) assertPostList(t, &model.PostList{ @@ -1811,7 +1811,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { require.Nil(t, err) th.App.Srv().Store.Post().InvalidateLastPostTimeCache(channelId) - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3, false) CheckNoError(t, resp) assertPostList(t, &model.PostList{ @@ -1836,7 +1836,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { require.Nil(t, err) th.App.Srv().Store.Post().InvalidateLastPostTimeCache(channelId) - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 3, 3, false) CheckNoError(t, resp) assertPostList(t, &model.PostList{ @@ -1876,7 +1876,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) { require.Nil(t, err) th.App.Srv().Store.Post().InvalidateLastPostTimeCache(channelId) - posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 1, 2) + posts, resp = Client.GetPostsAroundLastUnread(userId, channelId, 1, 2, false) CheckNoError(t, resp) assertPostList(t, &model.PostList{ @@ -2051,11 +2051,11 @@ func TestGetPostThread(t *testing.T) { post := &model.Post{ChannelId: th.BasicChannel.Id, Message: "zz" + model.NewId() + "a", RootId: th.BasicPost.Id} post, _ = Client.CreatePost(post) - list, resp := Client.GetPostThread(th.BasicPost.Id, "") + list, resp := Client.GetPostThread(th.BasicPost.Id, "", false) CheckNoError(t, resp) var list2 *model.PostList - list2, resp = Client.GetPostThread(th.BasicPost.Id, resp.Etag) + list2, resp = Client.GetPostThread(th.BasicPost.Id, resp.Etag, false) CheckEtag(t, list2, resp) require.Equal(t, th.BasicPost.Id, list.Order[0], "wrong order") @@ -2065,34 +2065,34 @@ func TestGetPostThread(t *testing.T) { _, ok = list.Posts[post.Id] require.True(t, ok, "should have had post") - _, resp = Client.GetPostThread("junk", "") + _, resp = Client.GetPostThread("junk", "", false) CheckBadRequestStatus(t, resp) - _, resp = Client.GetPostThread(model.NewId(), "") + _, resp = Client.GetPostThread(model.NewId(), "", false) CheckNotFoundStatus(t, resp) Client.RemoveUserFromChannel(th.BasicChannel.Id, th.BasicUser.Id) // Channel is public, should be able to read post - _, resp = Client.GetPostThread(th.BasicPost.Id, "") + _, resp = Client.GetPostThread(th.BasicPost.Id, "", false) CheckNoError(t, resp) privatePost := th.CreatePostWithClient(Client, th.BasicPrivateChannel) - _, resp = Client.GetPostThread(privatePost.Id, "") + _, resp = Client.GetPostThread(privatePost.Id, "", false) CheckNoError(t, resp) Client.RemoveUserFromChannel(th.BasicPrivateChannel.Id, th.BasicUser.Id) // Channel is private, should not be able to read post - _, resp = Client.GetPostThread(privatePost.Id, "") + _, resp = Client.GetPostThread(privatePost.Id, "", false) CheckForbiddenStatus(t, resp) Client.Logout() - _, resp = Client.GetPostThread(model.NewId(), "") + _, resp = Client.GetPostThread(model.NewId(), "", false) CheckUnauthorizedStatus(t, resp) - _, resp = th.SystemAdminClient.GetPostThread(th.BasicPost.Id, "") + _, resp = th.SystemAdminClient.GetPostThread(th.BasicPost.Id, "", false) CheckNoError(t, resp) } diff --git a/api4/team_test.go b/api4/team_test.go index 3a128d08ab..a5cbb90ee8 100644 --- a/api4/team_test.go +++ b/api4/team_test.go @@ -2704,7 +2704,7 @@ func TestImportTeam(t *testing.T) { CheckNoError(t, resp) require.Equal(t, importedChannel.Name, "general", "names did not match expected: general") - posts, resp := th.SystemAdminClient.GetPostsForChannel(importedChannel.Id, 0, 60, "") + posts, resp := th.SystemAdminClient.GetPostsForChannel(importedChannel.Id, 0, 60, "", false) CheckNoError(t, resp) require.Equal(t, posts.Posts[posts.Order[3]].Message, "This is a test post to test the import process", "missing posts in the import process") }) diff --git a/app/app_iface.go b/app/app_iface.go index e432671afc..697c17c803 100644 --- a/app/app_iface.go +++ b/app/app_iface.go @@ -631,13 +631,13 @@ type AppIface interface { GetPostAfterTime(channelId string, time int64) (*model.Post, *model.AppError) GetPostIdAfterTime(channelId string, time int64) (string, *model.AppError) GetPostIdBeforeTime(channelId string, time int64) (string, *model.AppError) - GetPostThread(postId string, skipFetchThreads bool) (*model.PostList, *model.AppError) + GetPostThread(postId string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) GetPosts(channelId string, offset int, limit int) (*model.PostList, *model.AppError) GetPostsAfterPost(options model.GetPostsOptions) (*model.PostList, *model.AppError) GetPostsAroundPost(before bool, options model.GetPostsOptions) (*model.PostList, *model.AppError) GetPostsBeforePost(options model.GetPostsOptions) (*model.PostList, *model.AppError) - GetPostsEtag(channelId string) string - GetPostsForChannelAroundLastUnread(channelId, userId string, limitBefore, limitAfter int, skipFetchThreads bool) (*model.PostList, *model.AppError) + GetPostsEtag(channelId string, collapsedThreads bool) string + GetPostsForChannelAroundLastUnread(channelId, userId string, limitBefore, limitAfter int, skipFetchThreads bool, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) GetPostsPage(options model.GetPostsOptions) (*model.PostList, *model.AppError) GetPostsSince(options model.GetPostsSinceOptions) (*model.PostList, *model.AppError) GetPreferenceByCategoryAndNameForUser(userId string, category string, preferenceName string) (*model.Preference, *model.AppError) diff --git a/app/file.go b/app/file.go index 80b5749db7..d68576839d 100644 --- a/app/file.go +++ b/app/file.go @@ -388,7 +388,7 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo { fileMigrationLock.Lock() defer fileMigrationLock.Unlock() - result, nErr := a.Srv().Store.Post().Get(post.Id, false) + result, nErr := a.Srv().Store.Post().Get(post.Id, false, false, false) if nErr != nil { mlog.Error("Unable to get post when migrating post to use FileInfos", mlog.Err(nErr), mlog.String("post_id", post.Id)) return []*model.FileInfo{} diff --git a/app/opentracing/opentracing_layer.go b/app/opentracing/opentracing_layer.go index c4e350b431..0e264c9a6f 100644 --- a/app/opentracing/opentracing_layer.go +++ b/app/opentracing/opentracing_layer.go @@ -7056,7 +7056,7 @@ func (a *OpenTracingAppLayer) GetPostIdBeforeTime(channelId string, time int64) return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) GetPostThread(postId string, skipFetchThreads bool) (*model.PostList, *model.AppError) { +func (a *OpenTracingAppLayer) GetPostThread(postId string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetPostThread") @@ -7068,7 +7068,7 @@ func (a *OpenTracingAppLayer) GetPostThread(postId string, skipFetchThreads bool }() defer span.Finish() - resultVar0, resultVar1 := a.app.GetPostThread(postId, skipFetchThreads) + resultVar0, resultVar1 := a.app.GetPostThread(postId, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if resultVar1 != nil { span.LogFields(spanlog.Error(resultVar1)) @@ -7166,7 +7166,7 @@ func (a *OpenTracingAppLayer) GetPostsBeforePost(options model.GetPostsOptions) return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) GetPostsEtag(channelId string) string { +func (a *OpenTracingAppLayer) GetPostsEtag(channelId string, collapsedThreads bool) string { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetPostsEtag") @@ -7178,12 +7178,12 @@ func (a *OpenTracingAppLayer) GetPostsEtag(channelId string) string { }() defer span.Finish() - resultVar0 := a.app.GetPostsEtag(channelId) + resultVar0 := a.app.GetPostsEtag(channelId, collapsedThreads) return resultVar0 } -func (a *OpenTracingAppLayer) GetPostsForChannelAroundLastUnread(channelId string, userId string, limitBefore int, limitAfter int, skipFetchThreads bool) (*model.PostList, *model.AppError) { +func (a *OpenTracingAppLayer) GetPostsForChannelAroundLastUnread(channelId string, userId string, limitBefore int, limitAfter int, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetPostsForChannelAroundLastUnread") @@ -7195,7 +7195,7 @@ func (a *OpenTracingAppLayer) GetPostsForChannelAroundLastUnread(channelId strin }() defer span.Finish() - resultVar0, resultVar1 := a.app.GetPostsForChannelAroundLastUnread(channelId, userId, limitBefore, limitAfter, skipFetchThreads) + resultVar0, resultVar1 := a.app.GetPostsForChannelAroundLastUnread(channelId, userId, limitBefore, limitAfter, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if resultVar1 != nil { span.LogFields(spanlog.Error(resultVar1)) diff --git a/app/plugin_api.go b/app/plugin_api.go index 7be60b03f8..cdc92c2b88 100644 --- a/app/plugin_api.go +++ b/app/plugin_api.go @@ -574,7 +574,7 @@ func (api *PluginAPI) DeletePost(postId string) *model.AppError { } func (api *PluginAPI) GetPostThread(postId string) (*model.PostList, *model.AppError) { - return api.app.GetPostThread(postId, false) + return api.app.GetPostThread(postId, false, false, false) } func (api *PluginAPI) GetPost(postId string) (*model.Post, *model.AppError) { diff --git a/app/post.go b/app/post.go index d39106504f..3807d7570b 100644 --- a/app/post.go +++ b/app/post.go @@ -185,7 +185,7 @@ func (a *App) CreatePost(post *model.Post, channel *model.Channel, triggerWebhoo if len(post.RootId) > 0 { pchan = make(chan store.StoreResult, 1) go func() { - r, pErr := a.Srv().Store.Post().Get(post.RootId, false) + r, pErr := a.Srv().Store.Post().Get(post.RootId, false, false, false) pchan <- store.StoreResult{Data: r, NErr: pErr} close(pchan) }() @@ -542,7 +542,7 @@ func (a *App) DeleteEphemeralPost(userId, postId string) { func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model.AppError) { post.SanitizeProps() - postLists, nErr := a.Srv().Store.Post().Get(post.Id, false) + postLists, nErr := a.Srv().Store.Post().Get(post.Id, false, false, false) if nErr != nil { var nfErr *store.ErrNotFound var invErr *store.ErrInvalidInput @@ -718,8 +718,8 @@ func (a *App) GetPosts(channelId string, offset int, limit int) (*model.PostList return postList, nil } -func (a *App) GetPostsEtag(channelId string) string { - return a.Srv().Store.Post().GetEtag(channelId, true) +func (a *App) GetPostsEtag(channelId string, collapsedThreads bool) string { + return a.Srv().Store.Post().GetEtag(channelId, true, collapsedThreads) } func (a *App) GetPostsSince(options model.GetPostsSinceOptions) (*model.PostList, *model.AppError) { @@ -746,8 +746,8 @@ func (a *App) GetSinglePost(postId string) (*model.Post, *model.AppError) { return post, nil } -func (a *App) GetPostThread(postId string, skipFetchThreads bool) (*model.PostList, *model.AppError) { - posts, err := a.Srv().Store.Post().Get(postId, skipFetchThreads) +func (a *App) GetPostThread(postId string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) { + posts, err := a.Srv().Store.Post().Get(postId, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { var nfErr *store.ErrNotFound var invErr *store.ErrInvalidInput @@ -792,7 +792,7 @@ func (a *App) GetFlaggedPostsForChannel(userId, channelId string, offset int, li } func (a *App) GetPermalinkPost(postId string, userId string) (*model.PostList, *model.AppError) { - list, nErr := a.Srv().Store.Post().Get(postId, false) + list, nErr := a.Srv().Store.Post().Get(postId, false, false, false) if nErr != nil { var nfErr *store.ErrNotFound var invErr *store.ErrInvalidInput @@ -975,7 +975,7 @@ func (a *App) AddCursorIdsForPostList(originalList *model.PostList, afterPost, b originalList.NextPostId = nextPostId originalList.PrevPostId = prevPostId } -func (a *App) GetPostsForChannelAroundLastUnread(channelId, userId string, limitBefore, limitAfter int, skipFetchThreads bool) (*model.PostList, *model.AppError) { +func (a *App) GetPostsForChannelAroundLastUnread(channelId, userId string, limitBefore, limitAfter int, skipFetchThreads bool, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) { var member *model.ChannelMember var err *model.AppError if member, err = a.GetChannelMember(channelId, userId); err != nil { @@ -991,7 +991,7 @@ func (a *App) GetPostsForChannelAroundLastUnread(channelId, userId string, limit return model.NewPostList(), nil } - postList, err := a.GetPostThread(lastUnreadPostId, skipFetchThreads) + postList, err := a.GetPostThread(lastUnreadPostId, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { return nil, err } @@ -999,13 +999,13 @@ func (a *App) GetPostsForChannelAroundLastUnread(channelId, userId string, limit // channel organically, those replies will be added below. postList.Order = []string{lastUnreadPostId} - if postListBefore, err := a.GetPostsBeforePost(model.GetPostsOptions{ChannelId: channelId, PostId: lastUnreadPostId, Page: PageDefault, PerPage: limitBefore, SkipFetchThreads: skipFetchThreads}); err != nil { + if postListBefore, err := a.GetPostsBeforePost(model.GetPostsOptions{ChannelId: channelId, PostId: lastUnreadPostId, Page: PageDefault, PerPage: limitBefore, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads, CollapsedThreadsExtended: collapsedThreadsExtended}); err != nil { return nil, err } else if postListBefore != nil { postList.Extend(postListBefore) } - if postListAfter, err := a.GetPostsAfterPost(model.GetPostsOptions{ChannelId: channelId, PostId: lastUnreadPostId, Page: PageDefault, PerPage: limitAfter - 1, SkipFetchThreads: skipFetchThreads}); err != nil { + if postListAfter, err := a.GetPostsAfterPost(model.GetPostsOptions{ChannelId: channelId, PostId: lastUnreadPostId, Page: PageDefault, PerPage: limitAfter - 1, SkipFetchThreads: skipFetchThreads, CollapsedThreads: collapsedThreads, CollapsedThreadsExtended: collapsedThreadsExtended}); err != nil { return nil, err } else if postListAfter != nil { postList.Extend(postListAfter) @@ -1434,7 +1434,7 @@ func (a *App) countMentionsFromPost(user *model.User, post *model.Post) (int, *m // A mapping of thread root IDs to whether or not a post in that thread mentions the user mentionedByThread := make(map[string]bool) - thread, err := a.GetPostThread(post.Id, false) + thread, err := a.GetPostThread(post.Id, false, false, false) if err != nil { return 0, err } diff --git a/app/post_test.go b/app/post_test.go index abe2c2e8a4..a484e2ff0b 100644 --- a/app/post_test.go +++ b/app/post_test.go @@ -1921,3 +1921,52 @@ func TestThreadMembership(t *testing.T) { require.Len(t, memberships, 2) }) } + +func TestCollapsedThreadFetch(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + th.App.UpdateConfig(func(cfg *model.Config) { + *cfg.ServiceSettings.ThreadAutoFollow = true + *cfg.ServiceSettings.CollapsedThreads = model.COLLAPSED_THREADS_DEFAULT_ON + }) + user1 := th.BasicUser + user2 := th.BasicUser2 + + t.Run("should only return root posts, enriched", func(t *testing.T) { + channel := th.CreateChannel(th.BasicTeam) + th.AddUserToChannel(user2, channel) + defer th.App.DeleteChannel(channel, user1.Id) + + postRoot, err := th.App.CreatePost(&model.Post{ + UserId: user1.Id, + ChannelId: channel.Id, + Message: "root post", + }, channel, false, true) + require.Nil(t, err) + + _, err = th.App.CreatePost(&model.Post{ + UserId: user1.Id, + ChannelId: channel.Id, + RootId: postRoot.Id, + Message: fmt.Sprintf("@%s", user2.Username), + }, channel, false, true) + require.Nil(t, err) + thread, nErr := th.App.Srv().Store.Thread().Get(postRoot.Id) + require.Nil(t, nErr) + require.Len(t, thread.Participants, 2) + th.App.MarkChannelAsUnreadFromPost(postRoot.Id, user1.Id) + l, err := th.App.GetPostsForChannelAroundLastUnread(channel.Id, user1.Id, 10, 10, true, true, false) + require.Nil(t, err) + require.Len(t, l.Order, 1) + require.EqualValues(t, 1, l.Posts[postRoot.Id].ReplyCount) + require.EqualValues(t, []string{user1.Id, user2.Id}, []string{l.Posts[postRoot.Id].Participants[0].Id, l.Posts[postRoot.Id].Participants[1].Id}) + require.Empty(t, l.Posts[postRoot.Id].Participants[0].Email) + require.NotZero(t, l.Posts[postRoot.Id].LastReplyAt) + + // try extended fetch + l, err = th.App.GetPostsForChannelAroundLastUnread(channel.Id, user1.Id, 10, 10, true, true, true) + require.Nil(t, err) + require.Len(t, l.Order, 1) + require.NotEmpty(t, l.Posts[postRoot.Id].Participants[0].Email) + }) +} diff --git a/config/client.go b/config/client.go index 4cc11ae694..9166b7e85c 100644 --- a/config/client.go +++ b/config/client.go @@ -135,6 +135,7 @@ func GenerateClientConfig(c *model.Config, telemetryID string, license *model.Li props["CustomUrlSchemes"] = strings.Join(c.DisplaySettings.CustomUrlSchemes, ",") props["IsDefaultMarketplace"] = strconv.FormatBool(*c.PluginSettings.MarketplaceUrl == model.PLUGIN_SETTINGS_DEFAULT_MARKETPLACE_URL) props["ExperimentalSharedChannels"] = "false" + props["CollapsedThreads"] = *c.ServiceSettings.CollapsedThreads if license != nil { props["ExperimentalHideTownSquareinLHS"] = strconv.FormatBool(*c.TeamSettings.ExperimentalHideTownSquareinLHS) diff --git a/model/client4.go b/model/client4.go index b160a4a79d..4d715ac39c 100644 --- a/model/client4.go +++ b/model/client4.go @@ -2903,8 +2903,12 @@ func (c *Client4) DeletePost(postId string) (bool, *Response) { } // GetPostThread gets a post with all the other posts in the same thread. -func (c *Client4) GetPostThread(postId string, etag string) (*PostList, *Response) { - r, err := c.DoApiGet(c.GetPostRoute(postId)+"/thread", etag) +func (c *Client4) GetPostThread(postId string, etag string, collapsedThreads bool) (*PostList, *Response) { + url := c.GetPostRoute(postId) + "/thread" + if collapsedThreads { + url += "?collapsedThreads=true" + } + r, err := c.DoApiGet(url, etag) if err != nil { return nil, BuildErrorResponse(r, err) } @@ -2913,8 +2917,11 @@ func (c *Client4) GetPostThread(postId string, etag string) (*PostList, *Respons } // GetPostsForChannel gets a page of posts with an array for ordering for a channel. -func (c *Client4) GetPostsForChannel(channelId string, page, perPage int, etag string) (*PostList, *Response) { +func (c *Client4) GetPostsForChannel(channelId string, page, perPage int, etag string, collapsedThreads bool) (*PostList, *Response) { query := fmt.Sprintf("?page=%v&per_page=%v", page, perPage) + if collapsedThreads { + query += "&collapsedThreads=true" + } r, err := c.DoApiGet(c.GetChannelRoute(channelId)+"/posts"+query, etag) if err != nil { return nil, BuildErrorResponse(r, err) @@ -2965,8 +2972,11 @@ func (c *Client4) GetFlaggedPostsForUserInChannel(userId string, channelId strin } // GetPostsSince gets posts created after a specified time as Unix time in milliseconds. -func (c *Client4) GetPostsSince(channelId string, time int64) (*PostList, *Response) { +func (c *Client4) GetPostsSince(channelId string, time int64, collapsedThreads bool) (*PostList, *Response) { query := fmt.Sprintf("?since=%v", time) + if collapsedThreads { + query += "&collapsedThreads=true" + } r, err := c.DoApiGet(c.GetChannelRoute(channelId)+"/posts"+query, "") if err != nil { return nil, BuildErrorResponse(r, err) @@ -2976,8 +2986,11 @@ func (c *Client4) GetPostsSince(channelId string, time int64) (*PostList, *Respo } // GetPostsAfter gets a page of posts that were posted after the post provided. -func (c *Client4) GetPostsAfter(channelId, postId string, page, perPage int, etag string) (*PostList, *Response) { +func (c *Client4) GetPostsAfter(channelId, postId string, page, perPage int, etag string, collapsedThreads bool) (*PostList, *Response) { query := fmt.Sprintf("?page=%v&per_page=%v&after=%v", page, perPage, postId) + if collapsedThreads { + query += "&collapsedThreads=true" + } r, err := c.DoApiGet(c.GetChannelRoute(channelId)+"/posts"+query, etag) if err != nil { return nil, BuildErrorResponse(r, err) @@ -2987,8 +3000,11 @@ func (c *Client4) GetPostsAfter(channelId, postId string, page, perPage int, eta } // GetPostsBefore gets a page of posts that were posted before the post provided. -func (c *Client4) GetPostsBefore(channelId, postId string, page, perPage int, etag string) (*PostList, *Response) { +func (c *Client4) GetPostsBefore(channelId, postId string, page, perPage int, etag string, collapsedThreads bool) (*PostList, *Response) { query := fmt.Sprintf("?page=%v&per_page=%v&before=%v", page, perPage, postId) + if collapsedThreads { + query += "&collapsedThreads=true" + } r, err := c.DoApiGet(c.GetChannelRoute(channelId)+"/posts"+query, etag) if err != nil { return nil, BuildErrorResponse(r, err) @@ -2998,8 +3014,11 @@ func (c *Client4) GetPostsBefore(channelId, postId string, page, perPage int, et } // GetPostsAroundLastUnread gets a list of posts around last unread post by a user in a channel. -func (c *Client4) GetPostsAroundLastUnread(userId, channelId string, limitBefore, limitAfter int) (*PostList, *Response) { +func (c *Client4) GetPostsAroundLastUnread(userId, channelId string, limitBefore, limitAfter int, collapsedThreads bool) (*PostList, *Response) { query := fmt.Sprintf("?limit_before=%v&limit_after=%v", limitBefore, limitAfter) + if collapsedThreads { + query += "&collapsedThreads=true" + } r, err := c.DoApiGet(c.GetUserRoute(userId)+c.GetChannelRoute(channelId)+"/posts/unread"+query, "") if err != nil { return nil, BuildErrorResponse(r, err) diff --git a/model/post.go b/model/post.go index 8133e6ae25..5d59bc58c5 100644 --- a/model/post.go +++ b/model/post.go @@ -99,8 +99,10 @@ type Post struct { HasReactions bool `json:"has_reactions,omitempty"` // Transient data populated before sending a post to the client - ReplyCount int64 `json:"reply_count" db:"-"` - Metadata *PostMetadata `json:"metadata,omitempty" db:"-"` + ReplyCount int64 `json:"reply_count" db:"-"` + LastReplyAt int64 `json:"last_reply_at" db:"-"` + Participants []*User `json:"participants" db:"-"` + Metadata *PostMetadata `json:"metadata,omitempty" db:"-"` } type PostEphemeral struct { @@ -201,6 +203,8 @@ func (o *Post) ShallowCopy(dst *Post) error { dst.PendingPostId = o.PendingPostId dst.HasReactions = o.HasReactions dst.ReplyCount = o.ReplyCount + dst.Participants = o.Participants + dst.LastReplyAt = o.LastReplyAt dst.Metadata = o.Metadata return nil } @@ -225,17 +229,21 @@ func (o *Post) ToUnsanitizedJson() string { } type GetPostsSinceOptions struct { - ChannelId string - Time int64 - SkipFetchThreads bool + ChannelId string + Time int64 + SkipFetchThreads bool + CollapsedThreads bool + CollapsedThreadsExtended bool } type GetPostsOptions struct { - ChannelId string - PostId string - Page int - PerPage int - SkipFetchThreads bool + ChannelId string + PostId string + Page int + PerPage int + SkipFetchThreads bool + CollapsedThreads bool + CollapsedThreadsExtended bool } func PostFromJson(data io.Reader) *Post { @@ -357,6 +365,9 @@ func (o *Post) SanitizeProps() { o.DelProp(member) } } + for _, p := range o.Participants { + p.Sanitize(map[string]bool{}) + } } func (o *Post) PreSave() { diff --git a/store/localcachelayer/main_test.go b/store/localcachelayer/main_test.go index 4090809c8f..096b6a29c0 100644 --- a/store/localcachelayer/main_test.go +++ b/store/localcachelayer/main_test.go @@ -110,8 +110,8 @@ func getMockStore() *mocks.Store { mockPostStoreEtagResult := fmt.Sprintf("%v.%v", model.CurrentVersion, 1) mockPostStore.On("ClearCaches") mockPostStore.On("InvalidateLastPostTimeCache", "channelId") - mockPostStore.On("GetEtag", "channelId", true).Return(mockPostStoreEtagResult) - mockPostStore.On("GetEtag", "channelId", false).Return(mockPostStoreEtagResult) + mockPostStore.On("GetEtag", "channelId", true, false).Return(mockPostStoreEtagResult) + mockPostStore.On("GetEtag", "channelId", false, false).Return(mockPostStoreEtagResult) mockPostStore.On("GetPostsSince", mockPostStoreOptions, true).Return(model.NewPostList(), nil) mockPostStore.On("GetPostsSince", mockPostStoreOptions, false).Return(model.NewPostList(), nil) mockStore.On("Post").Return(&mockPostStore) diff --git a/store/localcachelayer/post_layer.go b/store/localcachelayer/post_layer.go index 4293e73e69..125d4a4f28 100644 --- a/store/localcachelayer/post_layer.go +++ b/store/localcachelayer/post_layer.go @@ -59,7 +59,7 @@ func (s LocalCachePostStore) InvalidateLastPostTimeCache(channelId string) { } } -func (s LocalCachePostStore) GetEtag(channelId string, allowFromCache bool) string { +func (s LocalCachePostStore) GetEtag(channelId string, allowFromCache, collapsedThreads bool) string { if allowFromCache { var lastTime int64 if err := s.rootStore.doStandardReadCache(s.rootStore.lastPostTimeCache, channelId, &lastTime); err == nil { @@ -67,7 +67,7 @@ func (s LocalCachePostStore) GetEtag(channelId string, allowFromCache bool) stri } } - result := s.PostStore.GetEtag(channelId, allowFromCache) + result := s.PostStore.GetEtag(channelId, allowFromCache, collapsedThreads) splittedResult := strings.Split(result, ".") diff --git a/store/localcachelayer/post_layer_test.go b/store/localcachelayer/post_layer_test.go index cf0a1e7c55..c60d633539 100644 --- a/store/localcachelayer/post_layer_test.go +++ b/store/localcachelayer/post_layer_test.go @@ -36,11 +36,11 @@ func TestPostStoreLastPostTimeCache(t *testing.T) { expectedResult := fmt.Sprintf("%v.%v", model.CurrentVersion, fakeLastTime) - etag := cachedStore.Post().GetEtag(channelId, true) + etag := cachedStore.Post().GetEtag(channelId, true, false) assert.Equal(t, etag, expectedResult) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 1) - etag = cachedStore.Post().GetEtag(channelId, true) + etag = cachedStore.Post().GetEtag(channelId, true, false) assert.Equal(t, etag, expectedResult) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 1) }) @@ -51,9 +51,9 @@ func TestPostStoreLastPostTimeCache(t *testing.T) { cachedStore, err := NewLocalCacheLayer(mockStore, nil, nil, mockCacheProvider) require.NoError(t, err) - cachedStore.Post().GetEtag(channelId, true) + cachedStore.Post().GetEtag(channelId, true, false) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 1) - cachedStore.Post().GetEtag(channelId, false) + cachedStore.Post().GetEtag(channelId, false, false) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 2) }) @@ -63,10 +63,10 @@ func TestPostStoreLastPostTimeCache(t *testing.T) { cachedStore, err := NewLocalCacheLayer(mockStore, nil, nil, mockCacheProvider) require.NoError(t, err) - cachedStore.Post().GetEtag(channelId, true) + cachedStore.Post().GetEtag(channelId, true, false) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 1) cachedStore.Post().InvalidateLastPostTimeCache(channelId) - cachedStore.Post().GetEtag(channelId, true) + cachedStore.Post().GetEtag(channelId, true, false) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 2) }) @@ -76,10 +76,10 @@ func TestPostStoreLastPostTimeCache(t *testing.T) { cachedStore, err := NewLocalCacheLayer(mockStore, nil, nil, mockCacheProvider) require.NoError(t, err) - cachedStore.Post().GetEtag(channelId, true) + cachedStore.Post().GetEtag(channelId, true, false) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 1) cachedStore.Post().ClearCaches() - cachedStore.Post().GetEtag(channelId, true) + cachedStore.Post().GetEtag(channelId, true, false) mockStore.Post().(*mocks.PostStore).AssertNumberOfCalls(t, "GetEtag", 2) }) diff --git a/store/opentracinglayer/opentracinglayer.go b/store/opentracinglayer/opentracinglayer.go index bbd8d78ff2..0b579ab5ac 100644 --- a/store/opentracinglayer/opentracinglayer.go +++ b/store/opentracinglayer/opentracinglayer.go @@ -4898,7 +4898,7 @@ func (s *OpenTracingLayerPostStore) Delete(postId string, time int64, deleteByID return err } -func (s *OpenTracingLayerPostStore) Get(id string, skipFetchThreads bool) (*model.PostList, error) { +func (s *OpenTracingLayerPostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.Get") s.Root.Store.SetContext(newCtx) @@ -4907,7 +4907,7 @@ func (s *OpenTracingLayerPostStore) Get(id string, skipFetchThreads bool) (*mode }() defer span.Finish() - result, err := s.PostStore.Get(id, skipFetchThreads) + result, err := s.PostStore.Get(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) @@ -4934,7 +4934,7 @@ func (s *OpenTracingLayerPostStore) GetDirectPostParentsForExportAfter(limit int return result, err } -func (s *OpenTracingLayerPostStore) GetEtag(channelId string, allowFromCache bool) string { +func (s *OpenTracingLayerPostStore) GetEtag(channelId string, allowFromCache bool, collapsedThreads bool) string { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "PostStore.GetEtag") s.Root.Store.SetContext(newCtx) @@ -4943,7 +4943,7 @@ func (s *OpenTracingLayerPostStore) GetEtag(channelId string, allowFromCache boo }() defer span.Finish() - result := s.PostStore.GetEtag(channelId, allowFromCache) + result := s.PostStore.GetEtag(channelId, allowFromCache, collapsedThreads) return result } diff --git a/store/retrylayer/retrylayer.go b/store/retrylayer/retrylayer.go index 8a096f62e8..a73e1a7d0f 100644 --- a/store/retrylayer/retrylayer.go +++ b/store/retrylayer/retrylayer.go @@ -5302,11 +5302,11 @@ func (s *RetryLayerPostStore) Delete(postId string, time int64, deleteByID strin } -func (s *RetryLayerPostStore) Get(id string, skipFetchThreads bool) (*model.PostList, error) { +func (s *RetryLayerPostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { tries := 0 for { - result, err := s.PostStore.Get(id, skipFetchThreads) + result, err := s.PostStore.Get(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) if err == nil { return result, nil } @@ -5342,9 +5342,9 @@ func (s *RetryLayerPostStore) GetDirectPostParentsForExportAfter(limit int, afte } -func (s *RetryLayerPostStore) GetEtag(channelId string, allowFromCache bool) string { +func (s *RetryLayerPostStore) GetEtag(channelId string, allowFromCache bool, collapsedThreads bool) string { - return s.PostStore.GetEtag(channelId, allowFromCache) + return s.PostStore.GetEtag(channelId, allowFromCache, collapsedThreads) } diff --git a/store/searchlayer/post_layer.go b/store/searchlayer/post_layer.go index f6207bec54..57311cd1f2 100644 --- a/store/searchlayer/post_layer.go +++ b/store/searchlayer/post_layer.go @@ -108,7 +108,7 @@ func (s SearchPostStore) Delete(postId string, date int64, deletedByID string) e err := s.PostStore.Delete(postId, date, deletedByID) if err == nil { - postList, err2 := s.PostStore.Get(postId, true) + postList, err2 := s.PostStore.Get(postId, true, false, false) if postList != nil && len(postList.Order) > 0 { if err2 != nil { s.deletePostIndex(postList.Posts[postList.Order[0]]) diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index bfd1a480ca..df1bc00f66 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -30,6 +30,12 @@ type SqlPostStore struct { maxPostSizeCached int } +type postWithExtra struct { + ThreadReplyCount int64 + ThreadParticipants model.StringArray + model.Post +} + func (s *SqlPostStore) ClearCaches() { } @@ -418,8 +424,40 @@ func (s *SqlPostStore) GetFlaggedPostsForChannel(userId, channelId string, offse return pl, nil } +func (s *SqlPostStore) getPostWithCollapsedThreads(id string, extended bool) (*model.PostList, error) { + if len(id) == 0 { + return nil, store.NewErrInvalidInput("Post", "id", id) + } -func (s *SqlPostStore) Get(id string, skipFetchThreads bool) (*model.PostList, error) { + var columns []string + for _, c := range postSliceColumns() { + columns = append(columns, "Posts."+c) + } + columns = append(columns, "COALESCE(Threads.ReplyCount, 0) as ThreadReplyCount", "COALESCE(Threads.LastReplyAt, 0) as LastReplyAt", "COALESCE(Threads.Participants, '[]') as ThreadParticipants") + var post postWithExtra + + postFetchQuery, args, _ := s.getQueryBuilder(). + Select(columns...). + From("Posts"). + LeftJoin("Threads ON Threads.PostId = Id"). + Where(sq.Eq{"DeleteAt": 0}). + Where(sq.Eq{"Id": id}).ToSql() + + err := s.GetReplica().SelectOne(&post, postFetchQuery, args...) + if err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("Post", id) + } + + return nil, errors.Wrapf(err, "failed to get Post with id=%s", id) + } + return s.prepareThreadedResponse([]*postWithExtra{&post}, extended, false) +} + +func (s *SqlPostStore) Get(id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, error) { + if collapsedThreads { + return s.getPostWithCollapsedThreads(id, collapsedThreadsExtended) + } pl := model.NewPostList() if len(id) == 0 { @@ -484,9 +522,14 @@ type etagPosts struct { func (s *SqlPostStore) InvalidateLastPostTimeCache(channelId string) { } -func (s *SqlPostStore) GetEtag(channelId string, allowFromCache bool) string { +func (s *SqlPostStore) GetEtag(channelId string, allowFromCache, collapsedThreads bool) string { + q := s.getQueryBuilder().Select("Id", "UpdateAt").From("Posts").Where(sq.Eq{"ChannelId": channelId}).OrderBy("UpdateAt DESC").Limit(1) + if collapsedThreads { + q.Where(sq.Eq{"RootId": ""}) + } + sql, args, _ := q.ToSql() var et etagPosts - err := s.GetReplica().SelectOne(&et, "SELECT Id, UpdateAt FROM Posts WHERE ChannelId = :ChannelId ORDER BY UpdateAt DESC LIMIT 1", map[string]interface{}{"ChannelId": channelId}) + err := s.GetReplica().SelectOne(&et, sql, args...) var result string if err != nil { result = fmt.Sprintf("%v.%v", model.CurrentVersion, model.GetMillis()) @@ -616,10 +659,102 @@ func (s *SqlPostStore) PermanentDeleteByChannel(channelId string) error { return nil } +func (s *SqlPostStore) prepareThreadedResponse(posts []*postWithExtra, extended, reversed bool) (*model.PostList, error) { + list := model.NewPostList() + var userIds []string + userIdMap := map[string]bool{} + for _, thread := range posts { + for _, participantId := range thread.ThreadParticipants { + if _, ok := userIdMap[participantId]; !ok { + userIdMap[participantId] = true + userIds = append(userIds, participantId) + } + } + } + var users []*model.User + if extended { + var err error + users, err = s.User().GetProfileByIds(userIds, &store.UserGetByIdsOpts{}, true) + if err != nil { + return nil, err + } + } else { + for _, userId := range userIds { + users = append(users, &model.User{Id: userId}) + } + } + processPost := func(p *postWithExtra) error { + p.Post.ReplyCount = p.ThreadReplyCount + for _, th := range p.ThreadParticipants { + var participant *model.User + for _, u := range users { + if u.Id == th { + participant = u + break + } + } + if participant == nil { + return errors.New("cannot find thread participant with id=" + th) + } + p.Post.Participants = append(p.Post.Participants, participant) + } + return nil + } + + l := len(posts) + for i := range posts { + idx := i + // We need to flip the order if we selected backwards + + if reversed { + idx = l - i - 1 + } + if err := processPost(posts[idx]); err != nil { + return nil, err + } + list.AddPost(&posts[idx].Post) + list.AddOrder(posts[idx].Id) + } + + return list, nil +} + +func (s *SqlPostStore) getPostsCollapsedThreads(options model.GetPostsOptions) (*model.PostList, error) { + var columns []string + for _, c := range postSliceColumns() { + columns = append(columns, "Posts."+c) + } + columns = append(columns, "COALESCE(Threads.ReplyCount, 0) as ThreadReplyCount", "COALESCE(Threads.LastReplyAt, 0) as LastReplyAt", "COALESCE(Threads.Participants, '[]') as ThreadParticipants") + var posts []*postWithExtra + offset := options.PerPage * options.Page + + postFetchQuery, args, _ := s.getQueryBuilder(). + Select(columns...). + From("Posts"). + LeftJoin("Threads ON Threads.PostId = Id"). + Where(sq.Eq{"DeleteAt": 0}). + Where(sq.Eq{"Posts.ChannelId": options.ChannelId}). + Where(sq.Eq{"RootId": ""}). + Limit(uint64(options.PerPage)). + Offset(uint64(offset)). + OrderBy("CreateAt DESC").ToSql() + + _, err := s.GetReplica().Select(&posts, postFetchQuery, args...) + + if err != nil { + return nil, errors.Wrapf(err, "failed to find Posts with channelId=%s", options.ChannelId) + } + + return s.prepareThreadedResponse(posts, options.CollapsedThreadsExtended, false) +} + func (s *SqlPostStore) GetPosts(options model.GetPostsOptions, _ bool) (*model.PostList, error) { if options.PerPage > 1000 { return nil, store.NewErrInvalidInput("Post", "", options.PerPage) } + if options.CollapsedThreads { + return s.getPostsCollapsedThreads(options) + } offset := options.PerPage * options.Page rpc := make(chan store.StoreResult, 1) @@ -664,7 +799,36 @@ func (s *SqlPostStore) GetPosts(options model.GetPostsOptions, _ bool) (*model.P return list, nil } +func (s *SqlPostStore) getPostsSinceCollapsedThreads(options model.GetPostsSinceOptions) (*model.PostList, error) { + var columns []string + for _, c := range postSliceColumns() { + columns = append(columns, "Posts."+c) + } + columns = append(columns, "COALESCE(Threads.ReplyCount, 0) as ThreadReplyCount", "COALESCE(Threads.LastReplyAt, 0) as LastReplyAt", "COALESCE(Threads.Participants, '[]') as ThreadParticipants") + var posts []*postWithExtra + + postFetchQuery, args, _ := s.getQueryBuilder(). + Select(columns...). + From("Posts"). + LeftJoin("Threads ON Threads.PostId = Id"). + Where(sq.Eq{"DeleteAt": 0}). + Where(sq.Eq{"Posts.ChannelId": options.ChannelId}). + Where(sq.Gt{"UpdateAt": options.Time}). + Where(sq.Eq{"RootId": ""}). + OrderBy("CreateAt DESC").ToSql() + + _, err := s.GetReplica().Select(&posts, postFetchQuery, args...) + + if err != nil { + return nil, errors.Wrapf(err, "failed to find Posts with channelId=%s", options.ChannelId) + } + return s.prepareThreadedResponse(posts, options.CollapsedThreadsExtended, false) +} + func (s *SqlPostStore) GetPostsSince(options model.GetPostsSinceOptions, allowFromCache bool) (*model.PostList, error) { + if options.CollapsedThreads { + return s.getPostsSinceCollapsedThreads(options) + } var posts []*model.Post replyCountQuery1 := "" @@ -753,7 +917,8 @@ func (s *SqlPostStore) getPostsAround(before bool, options model.GetPostsOptions } offset := options.Page * options.PerPage - var posts, parents []*model.Post + var posts []*postWithExtra + var parents []*model.Post var direction string var sort string @@ -771,20 +936,30 @@ func (s *SqlPostStore) getPostsAround(before bool, options model.GetPostsOptions if s.DriverName() == model.DATABASE_DRIVER_MYSQL { table += " USE INDEX(idx_posts_channel_id_delete_at_create_at)" } - + columns := []string{"p.*"} + if options.CollapsedThreads { + columns = append(columns, "COALESCE(Threads.ReplyCount, 0) as ThreadReplyCount", "COALESCE(Threads.LastReplyAt, 0) as LastReplyAt", "COALESCE(Threads.Participants, '[]') as ThreadParticipants") + } + query := s.getQueryBuilder().Select(columns...) replyCountSubQuery := s.getQueryBuilder().Select("COUNT(Posts.Id)").From("Posts").Where(sq.Expr("Posts.RootId = (CASE WHEN p.RootId = '' THEN p.Id ELSE p.RootId END) AND Posts.DeleteAt = 0")) - query := s.getQueryBuilder().Select("p.*") - query = query.Column(sq.Alias(replyCountSubQuery, "ReplyCount")) + + conditions := sq.And{ + sq.Expr(`CreateAt `+direction+` (SELECT CreateAt FROM Posts WHERE Id = ?)`, options.PostId), + sq.Eq{"p.ChannelId": options.ChannelId}, + sq.Eq{"DeleteAt": int(0)}, + } + if options.CollapsedThreads { + conditions = append(conditions, sq.Eq{"RootId": ""}) + query = query.LeftJoin("Threads ON Threads.PostId = p.Id") + } else { + query = query.Column(sq.Alias(replyCountSubQuery, "ReplyCount")) + } query = query.From(table). - Where(sq.And{ - sq.Expr(`CreateAt `+direction+` (SELECT CreateAt FROM Posts WHERE Id = ?)`, options.PostId), - sq.Eq{"ChannelId": options.ChannelId}, - sq.Eq{"DeleteAt": int(0)}, - }). + Where(conditions). // Adding ChannelId and DeleteAt order columns // to let mysql choose the "idx_posts_channel_id_delete_at_create_at" index always. // See MM-24170. - OrderBy("ChannelId", "DeleteAt", "CreateAt "+sort). + OrderBy("p.ChannelId", "DeleteAt", "CreateAt "+sort). Limit(uint64(options.PerPage)). Offset(uint64(offset)) @@ -797,7 +972,7 @@ func (s *SqlPostStore) getPostsAround(before bool, options model.GetPostsOptions return nil, errors.Wrapf(err, "failed to find Posts with channelId=%s", options.ChannelId) } - if len(posts) > 0 { + if !options.CollapsedThreads && len(posts) > 0 { rootIds := []string{} for _, post := range posts { rootIds = append(rootIds, post.Id) @@ -822,31 +997,20 @@ func (s *SqlPostStore) getPostsAround(before bool, options model.GetPostsOptions }). OrderBy("CreateAt DESC") - rootQueryString, rootArgs, err := rootQuery.ToSql() + rootQueryString, rootArgs, nErr := rootQuery.ToSql() - if err != nil { - return nil, errors.Wrap(err, "post_tosql") + if nErr != nil { + return nil, errors.Wrap(nErr, "post_tosql") } - _, err = s.GetMaster().Select(&parents, rootQueryString, rootArgs...) - if err != nil { - return nil, errors.Wrapf(err, "failed to find Posts with channelId=%s", options.ChannelId) + _, nErr = s.GetMaster().Select(&parents, rootQueryString, rootArgs...) + if nErr != nil { + return nil, errors.Wrapf(nErr, "failed to find Posts with channelId=%s", options.ChannelId) } } - list := model.NewPostList() - - // We need to flip the order if we selected backwards - if before { - for _, p := range posts { - list.AddPost(p) - list.AddOrder(p.Id) - } - } else { - l := len(posts) - for i := range posts { - list.AddPost(posts[l-i-1]) - list.AddOrder(posts[l-i-1].Id) - } + list, err := s.prepareThreadedResponse(posts, options.CollapsedThreadsExtended, !before) + if err != nil { + return nil, err } for _, p := range parents { diff --git a/store/sqlstore/thread_store.go b/store/sqlstore/thread_store.go index 4b749bc13e..e43a3bf8bb 100644 --- a/store/sqlstore/thread_store.go +++ b/store/sqlstore/thread_store.go @@ -234,8 +234,9 @@ func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.Get } var users []*model.User if opts.Extended { - query, args, _ := s.getQueryBuilder().Select("*").From("Users").Where(sq.Eq{"Id": userIds}).ToSql() - if _, err := s.GetReplica().Select(&users, query, args...); err != nil { + var err error + users, err = s.User().GetProfileByIds(userIds, &store.UserGetByIdsOpts{}, true) + if err != nil { return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId) } } else { @@ -414,6 +415,18 @@ func (s *SqlThreadStore) CreateMembershipIfNeeded(userId, postId string, followi LastUpdated: now, UnreadMentions: int64(mentions), }) + if err != nil { + return err + } + + thread, err := s.Get(postId) + if err != nil { + return err + } + if !thread.Participants.Contains(userId) { + thread.Participants = append(thread.Participants, userId) + _, err = s.Update(thread) + } return err } diff --git a/store/sqlstore/upgrade.go b/store/sqlstore/upgrade.go index ab0100e7d5..4c1882c2d6 100644 --- a/store/sqlstore/upgrade.go +++ b/store/sqlstore/upgrade.go @@ -964,6 +964,7 @@ func upgradeDatabaseToVersion532(sqlStore *SqlStore) { // if shouldPerformUpgrade(sqlStore, Version5310, Version5320) { // allow 10 files per post sqlStore.AlterColumnTypeIfExists("Posts", "FileIds", "text", "varchar(300)") + sqlStore.CreateColumnIfNotExists("ThreadMemberships", "UnreadMentions", "bigint", "bigint", "0") sqlStore.CreateColumnIfNotExistsNoDefault("Channels", "Shared", "tinyint(1)", "boolean") sqlStore.CreateColumnIfNotExists("ThreadMemberships", "UnreadMentions", "bigint", "bigint", "0") // saveSchemaVersion(sqlStore, Version5320) diff --git a/store/store.go b/store/store.go index d9826a43d1..85895adea7 100644 --- a/store/store.go +++ b/store/store.go @@ -272,7 +272,7 @@ type PostStore interface { SaveMultiple(posts []*model.Post) ([]*model.Post, int, error) Save(post *model.Post) (*model.Post, error) Update(newPost *model.Post, oldPost *model.Post) (*model.Post, error) - Get(id string, skipFetchThreads bool) (*model.PostList, error) + Get(id string, skipFetchThreads, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, error) GetSingle(id string) (*model.Post, error) Delete(postId string, time int64, deleteByID string) error PermanentDeleteByUser(userId string) error @@ -288,7 +288,7 @@ type PostStore interface { GetPostAfterTime(channelId string, time int64) (*model.Post, error) GetPostIdAfterTime(channelId string, time int64) (string, error) GetPostIdBeforeTime(channelId string, time int64) (string, error) - GetEtag(channelId string, allowFromCache bool) string + GetEtag(channelId string, allowFromCache bool, collapsedThreads bool) string Search(teamId string, userId string, params *model.SearchParams) (*model.PostList, error) AnalyticsUserCountsWithPostsByDay(teamId string) (model.AnalyticsRows, error) AnalyticsPostCountsByDay(options *model.AnalyticsPostCountsOptions) (model.AnalyticsRows, error) diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index cb3c70d1e7..ac451738e0 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -100,13 +100,13 @@ func (_m *PostStore) Delete(postId string, time int64, deleteByID string) error return r0 } -// Get provides a mock function with given fields: id, skipFetchThreads -func (_m *PostStore) Get(id string, skipFetchThreads bool) (*model.PostList, error) { - ret := _m.Called(id, skipFetchThreads) +// Get provides a mock function with given fields: id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended +func (_m *PostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { + ret := _m.Called(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) var r0 *model.PostList - if rf, ok := ret.Get(0).(func(string, bool) *model.PostList); ok { - r0 = rf(id, skipFetchThreads) + if rf, ok := ret.Get(0).(func(string, bool, bool, bool) *model.PostList); ok { + r0 = rf(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.PostList) @@ -114,8 +114,8 @@ func (_m *PostStore) Get(id string, skipFetchThreads bool) (*model.PostList, err } var r1 error - if rf, ok := ret.Get(1).(func(string, bool) error); ok { - r1 = rf(id, skipFetchThreads) + if rf, ok := ret.Get(1).(func(string, bool, bool, bool) error); ok { + r1 = rf(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) } else { r1 = ret.Error(1) } @@ -146,13 +146,13 @@ func (_m *PostStore) GetDirectPostParentsForExportAfter(limit int, afterId strin return r0, r1 } -// GetEtag provides a mock function with given fields: channelId, allowFromCache -func (_m *PostStore) GetEtag(channelId string, allowFromCache bool) string { - ret := _m.Called(channelId, allowFromCache) +// GetEtag provides a mock function with given fields: channelId, allowFromCache, collapsedThreads +func (_m *PostStore) GetEtag(channelId string, allowFromCache bool, collapsedThreads bool) string { + ret := _m.Called(channelId, allowFromCache, collapsedThreads) var r0 string - if rf, ok := ret.Get(0).(func(string, bool) string); ok { - r0 = rf(channelId, allowFromCache) + if rf, ok := ret.Get(0).(func(string, bool, bool) string); ok { + r0 = rf(channelId, allowFromCache, collapsedThreads) } else { r0 = ret.Get(0).(string) } diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index 2d244b9786..7cad4835b7 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -407,23 +407,23 @@ func testPostStoreGet(t *testing.T, ss store.Store) { o1.UserId = model.NewId() o1.Message = "zz" + model.NewId() + "b" - etag1 := ss.Post().GetEtag(o1.ChannelId, false) + etag1 := ss.Post().GetEtag(o1.ChannelId, false, false) require.Equal(t, 0, strings.Index(etag1, model.CurrentVersion+"."), "Invalid Etag") o1, err := ss.Post().Save(o1) require.Nil(t, err) - etag2 := ss.Post().GetEtag(o1.ChannelId, false) + etag2 := ss.Post().GetEtag(o1.ChannelId, false, false) require.Equal(t, 0, strings.Index(etag2, fmt.Sprintf("%v.%v", model.CurrentVersion, o1.UpdateAt)), "Invalid Etag") - r1, err := ss.Post().Get(o1.Id, false) + r1, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) require.Equal(t, r1.Posts[o1.Id].CreateAt, o1.CreateAt, "invalid returned post") - _, err = ss.Post().Get("123", false) + _, err = ss.Post().Get("123", false, false, false) require.NotNil(t, err, "Missing id should have failed") - _, err = ss.Post().Get("", false) + _, err = ss.Post().Get("", false, false, false) require.NotNil(t, err, "should fail for blank post ids") } @@ -468,15 +468,15 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { o3, err = ss.Post().Save(o3) require.Nil(t, err) - r1, err := ss.Post().Get(o1.Id, false) + r1, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o1.Id, false) + r2, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false) + r3, err := ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err) ro3 := r3.Posts[o3.Id] @@ -487,7 +487,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o1a, ro1) require.Nil(t, err) - r1, err = ss.Post().Get(o1.Id, false) + r1, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro1a := r1.Posts[o1.Id] @@ -498,7 +498,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o2a, ro2) require.Nil(t, err) - r2, err = ss.Post().Get(o1.Id, false) + r2, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro2a := r2.Posts[o2.Id] @@ -509,7 +509,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o3a, ro3) require.Nil(t, err) - r3, err = ss.Post().Get(o3.Id, false) + r3, err = ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err) ro3a := r3.Posts[o3.Id] @@ -525,7 +525,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { }) require.Nil(t, err) - r4, err := ss.Post().Get(o4.Id, false) + r4, err := ss.Post().Get(o4.Id, false, false, false) require.Nil(t, err) ro4 := r4.Posts[o4.Id] @@ -535,7 +535,7 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { _, err = ss.Post().Update(o4a, ro4) require.Nil(t, err) - r4, err = ss.Post().Get(o4.Id, false) + r4, err = ss.Post().Get(o4.Id, false, false, false) require.Nil(t, err) ro4a := r4.Posts[o4.Id] @@ -550,13 +550,13 @@ func testPostStoreDelete(t *testing.T, ss store.Store) { o1.Message = "zz" + model.NewId() + "b" deleteByID := model.NewId() - etag1 := ss.Post().GetEtag(o1.ChannelId, false) + etag1 := ss.Post().GetEtag(o1.ChannelId, false, false) require.Equal(t, 0, strings.Index(etag1, model.CurrentVersion+"."), "Invalid Etag") o1, err := ss.Post().Save(o1) require.Nil(t, err) - r1, err := ss.Post().Get(o1.Id, false) + r1, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) require.Equal(t, r1.Posts[o1.Id].CreateAt, o1.CreateAt, "invalid returned post") @@ -569,10 +569,10 @@ func testPostStoreDelete(t *testing.T, ss store.Store) { assert.Equal(t, deleteByID, actual, "Expected (*Post).Props[model.POST_PROPS_DELETE_BY] to be %v but got %v.", deleteByID, actual) - r3, err := ss.Post().Get(o1.Id, false) + r3, err := ss.Post().Get(o1.Id, false, false, false) require.NotNil(t, err, "Missing id should have failed - PostList %v", r3) - etag2 := ss.Post().GetEtag(o1.ChannelId, false) + etag2 := ss.Post().GetEtag(o1.ChannelId, false, false) require.Equal(t, 0, strings.Index(etag2, model.CurrentVersion+"."), "Invalid Etag") } @@ -596,10 +596,10 @@ func testPostStoreDelete1Level(t *testing.T, ss store.Store) { err = ss.Post().Delete(o1.Id, model.GetMillis(), "") require.Nil(t, err) - _, err = ss.Post().Get(o1.Id, false) + _, err = ss.Post().Get(o1.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o2.Id, false) + _, err = ss.Post().Get(o2.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") } @@ -639,16 +639,16 @@ func testPostStoreDelete2Level(t *testing.T, ss store.Store) { err = ss.Post().Delete(o1.Id, model.GetMillis(), "") require.Nil(t, err) - _, err = ss.Post().Get(o1.Id, false) + _, err = ss.Post().Get(o1.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o2.Id, false) + _, err = ss.Post().Get(o2.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o3.Id, false) + _, err = ss.Post().Get(o3.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o4.Id, false) + _, err = ss.Post().Get(o4.Id, false, false, false) require.Nil(t, err) } @@ -679,16 +679,16 @@ func testPostStorePermDelete1Level(t *testing.T, ss store.Store) { err2 := ss.Post().PermanentDeleteByUser(o2.UserId) require.Nil(t, err2) - _, err = ss.Post().Get(o1.Id, false) + _, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err, "Deleted id shouldn't have failed") - _, err = ss.Post().Get(o2.Id, false) + _, err = ss.Post().Get(o2.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") err = ss.Post().PermanentDeleteByChannel(o3.ChannelId) require.Nil(t, err) - _, err = ss.Post().Get(o3.Id, false) + _, err = ss.Post().Get(o3.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") } @@ -719,13 +719,13 @@ func testPostStorePermDelete1Level2(t *testing.T, ss store.Store) { err2 := ss.Post().PermanentDeleteByUser(o1.UserId) require.Nil(t, err2) - _, err = ss.Post().Get(o1.Id, false) + _, err = ss.Post().Get(o1.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o2.Id, false) + _, err = ss.Post().Get(o2.Id, false, false, false) require.NotNil(t, err, "Deleted id should have failed") - _, err = ss.Post().Get(o3.Id, false) + _, err = ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err, "Deleted id should have failed") } @@ -755,7 +755,7 @@ func testPostStoreGetWithChildren(t *testing.T, ss store.Store) { o3, err = ss.Post().Save(o3) require.Nil(t, err) - pl, err := ss.Post().Get(o1.Id, false) + pl, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) require.Len(t, pl.Posts, 3, "invalid returned post") @@ -763,7 +763,7 @@ func testPostStoreGetWithChildren(t *testing.T, ss store.Store) { dErr := ss.Post().Delete(o3.Id, model.GetMillis(), "") require.Nil(t, dErr) - pl, err = ss.Post().Get(o1.Id, false) + pl, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) require.Len(t, pl.Posts, 2, "invalid returned post") @@ -771,7 +771,7 @@ func testPostStoreGetWithChildren(t *testing.T, ss store.Store) { dErr = ss.Post().Delete(o2.Id, model.GetMillis(), "") require.Nil(t, dErr) - pl, err = ss.Post().Get(o1.Id, false) + pl, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) require.Len(t, pl.Posts, 1, "invalid returned post") @@ -1181,6 +1181,101 @@ func testPostStoreGetPostsBeforeAfter(t *testing.T, ss store.Store) { }, postList.Posts) }) }) + t.Run("with threads (collapsedThreads)", func(t *testing.T) { + channelId := model.NewId() + userId := model.NewId() + + // This creates a series of posts that looks like: + // post1 + // post2 + // post3 (in response to post1) + // post4 (in response to post2) + // post5 + // post6 (in response to post2) + + post1, err := ss.Post().Save(&model.Post{ + ChannelId: channelId, + UserId: userId, + Message: "post1", + }) + require.Nil(t, err) + post1.ReplyCount = 1 + time.Sleep(time.Millisecond) + + post2, err := ss.Post().Save(&model.Post{ + ChannelId: channelId, + UserId: userId, + Message: "post2", + }) + require.Nil(t, err) + post2.ReplyCount = 2 + time.Sleep(time.Millisecond) + + post3, err := ss.Post().Save(&model.Post{ + ChannelId: channelId, + UserId: userId, + ParentId: post1.Id, + RootId: post1.Id, + Message: "post3", + }) + require.Nil(t, err) + post3.ReplyCount = 1 + time.Sleep(time.Millisecond) + + post4, err := ss.Post().Save(&model.Post{ + ChannelId: channelId, + UserId: userId, + RootId: post2.Id, + ParentId: post2.Id, + Message: "post4", + }) + require.Nil(t, err) + post4.ReplyCount = 2 + time.Sleep(time.Millisecond) + + post5, err := ss.Post().Save(&model.Post{ + ChannelId: channelId, + UserId: userId, + Message: "post5", + }) + require.Nil(t, err) + time.Sleep(time.Millisecond) + + post6, err := ss.Post().Save(&model.Post{ + ChannelId: channelId, + UserId: userId, + ParentId: post2.Id, + RootId: post2.Id, + Message: "post6", + }) + post6.ReplyCount = 2 + require.Nil(t, err) + + // Adding a post to a thread changes the UpdateAt timestamp of the parent post + post1.UpdateAt = post3.UpdateAt + post2.UpdateAt = post6.UpdateAt + + t.Run("should return each root post before a post", func(t *testing.T) { + postList, err := ss.Post().GetPostsBefore(model.GetPostsOptions{ChannelId: channelId, PostId: post4.Id, PerPage: 2, CollapsedThreads: true}) + assert.Nil(t, err) + + assert.Equal(t, []string{post2.Id, post1.Id}, postList.Order) + }) + + t.Run("should return each root post before a post with limit", func(t *testing.T) { + postList, err := ss.Post().GetPostsBefore(model.GetPostsOptions{ChannelId: channelId, PostId: post4.Id, PerPage: 1, CollapsedThreads: true}) + assert.Nil(t, err) + + assert.Equal(t, []string{post2.Id}, postList.Order) + }) + + t.Run("should return each root after a post", func(t *testing.T) { + postList, err := ss.Post().GetPostsAfter(model.GetPostsOptions{ChannelId: channelId, PostId: post4.Id, PerPage: 2, CollapsedThreads: true}) + require.Nil(t, err) + + assert.Equal(t, []string{post5.Id}, postList.Order) + }) + }) } func testPostStoreGetPostsSince(t *testing.T, ss store.Store) { @@ -2146,23 +2241,23 @@ func testPostStoreOverwriteMultiple(t *testing.T, ss store.Store) { }) require.Nil(t, err) - r1, err := ss.Post().Get(o1.Id, false) + r1, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o2.Id, false) + r2, err := ss.Post().Get(o2.Id, false, false, false) require.Nil(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false) + r3, err := ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err) ro3 := r3.Posts[o3.Id] - r4, err := ss.Post().Get(o4.Id, false) + r4, err := ss.Post().Get(o4.Id, false, false, false) require.Nil(t, err) ro4 := r4.Posts[o4.Id] - r5, err := ss.Post().Get(o5.Id, false) + r5, err := ss.Post().Get(o5.Id, false, false, false) require.Nil(t, err) ro5 := r5.Posts[o5.Id] @@ -2188,15 +2283,15 @@ func testPostStoreOverwriteMultiple(t *testing.T, ss store.Store) { require.Nil(t, err) require.Equal(t, -1, errIdx) - r1, nErr := ss.Post().Get(o1.Id, false) + r1, nErr := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, nErr) ro1a := r1.Posts[o1.Id] - r2, nErr = ss.Post().Get(o1.Id, false) + r2, nErr = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, nErr) ro2a := r2.Posts[o2.Id] - r3, nErr = ss.Post().Get(o3.Id, false) + r3, nErr = ss.Post().Get(o3.Id, false, false, false) require.Nil(t, nErr) ro3a := r3.Posts[o3.Id] @@ -2218,11 +2313,11 @@ func testPostStoreOverwriteMultiple(t *testing.T, ss store.Store) { require.Nil(t, err) require.Equal(t, -1, errIdx) - r4, nErr := ss.Post().Get(o4.Id, false) + r4, nErr := ss.Post().Get(o4.Id, false, false, false) require.Nil(t, nErr) ro4a := r4.Posts[o4.Id] - r5, nErr = ss.Post().Get(o5.Id, false) + r5, nErr = ss.Post().Get(o5.Id, false, false, false) require.Nil(t, nErr) ro5a := r5.Posts[o5.Id] @@ -2265,19 +2360,19 @@ func testPostStoreOverwrite(t *testing.T, ss store.Store) { }) require.Nil(t, err) - r1, err := ss.Post().Get(o1.Id, false) + r1, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o2.Id, false) + r2, err := ss.Post().Get(o2.Id, false, false, false) require.Nil(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false) + r3, err := ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err) ro3 := r3.Posts[o3.Id] - r4, err := ss.Post().Get(o4.Id, false) + r4, err := ss.Post().Get(o4.Id, false, false, false) require.Nil(t, err) ro4 := r4.Posts[o4.Id] @@ -2302,15 +2397,15 @@ func testPostStoreOverwrite(t *testing.T, ss store.Store) { _, err = ss.Post().Overwrite(o3a) require.Nil(t, err) - r1, err = ss.Post().Get(o1.Id, false) + r1, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro1a := r1.Posts[o1.Id] - r2, err = ss.Post().Get(o1.Id, false) + r2, err = ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro2a := r2.Posts[o2.Id] - r3, err = ss.Post().Get(o3.Id, false) + r3, err = ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err) ro3a := r3.Posts[o3.Id] @@ -2326,7 +2421,7 @@ func testPostStoreOverwrite(t *testing.T, ss store.Store) { _, err = ss.Post().Overwrite(o4a) require.Nil(t, err) - r4, err = ss.Post().Get(o4.Id, false) + r4, err = ss.Post().Get(o4.Id, false, false, false) require.Nil(t, err) ro4a := r4.Posts[o4.Id] @@ -2357,15 +2452,15 @@ func testPostStoreGetPostsByIds(t *testing.T, ss store.Store) { o3, err = ss.Post().Save(o3) require.Nil(t, err) - r1, err := ss.Post().Get(o1.Id, false) + r1, err := ss.Post().Get(o1.Id, false, false, false) require.Nil(t, err) ro1 := r1.Posts[o1.Id] - r2, err := ss.Post().Get(o2.Id, false) + r2, err := ss.Post().Get(o2.Id, false, false, false) require.Nil(t, err) ro2 := r2.Posts[o2.Id] - r3, err := ss.Post().Get(o3.Id, false) + r3, err := ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err) ro3 := r3.Posts[o3.Id] @@ -2472,13 +2567,13 @@ func testPostStorePermanentDeleteBatch(t *testing.T, ss store.Store) { _, err = ss.Post().PermanentDeleteBatch(2000, 1000) require.Nil(t, err) - _, err = ss.Post().Get(o1.Id, false) + _, err = ss.Post().Get(o1.Id, false, false, false) require.NotNil(t, err, "Should have not found post 1 after purge") - _, err = ss.Post().Get(o2.Id, false) + _, err = ss.Post().Get(o2.Id, false, false, false) require.NotNil(t, err, "Should have not found post 2 after purge") - _, err = ss.Post().Get(o3.Id, false) + _, err = ss.Post().Get(o3.Id, false, false, false) require.Nil(t, err, "Should have not found post 3 after purge") } diff --git a/store/storetest/reaction_store.go b/store/storetest/reaction_store.go index 5eadfc9fab..1d04e8ce29 100644 --- a/store/storetest/reaction_store.go +++ b/store/storetest/reaction_store.go @@ -50,7 +50,7 @@ func testReactionSave(t *testing.T, ss store.Store) { assert.Equal(t, saved.EmojiName, reaction1.EmojiName, "should've saved reaction emoji_name and returned it") var secondUpdateAt int64 - postList, err := ss.Post().Get(reaction1.PostId, false) + postList, err := ss.Post().Get(reaction1.PostId, false, false, false) require.Nil(t, err) assert.True(t, postList.Posts[post.Id].HasReactions, "should've set HasReactions = true on post") @@ -74,7 +74,7 @@ func testReactionSave(t *testing.T, ss store.Store) { _, nErr = ss.Reaction().Save(reaction2) require.Nil(t, nErr) - postList, err = ss.Post().Get(reaction2.PostId, false) + postList, err = ss.Post().Get(reaction2.PostId, false, false, false) require.Nil(t, err) assert.NotEqual(t, postList.Posts[post.Id].UpdateAt, secondUpdateAt, "should've marked post as updated even if HasReactions doesn't change") @@ -123,7 +123,7 @@ func testReactionDelete(t *testing.T, ss store.Store) { _, nErr := ss.Reaction().Save(reaction) require.Nil(t, nErr) - result, err := ss.Post().Get(reaction.PostId, false) + result, err := ss.Post().Get(reaction.PostId, false, false, false) require.Nil(t, err) firstUpdateAt := result.Posts[post.Id].UpdateAt @@ -136,7 +136,7 @@ func testReactionDelete(t *testing.T, ss store.Store) { assert.Empty(t, reactions, "should've deleted reaction") - postList, err := ss.Post().Get(post.Id, false) + postList, err := ss.Post().Get(post.Id, false, false, false) require.Nil(t, err) assert.False(t, postList.Posts[post.Id].HasReactions, "should've set HasReactions = false on post") @@ -297,15 +297,15 @@ func testReactionDeleteAllWithEmojiName(t *testing.T, ss store.Store) { assert.Empty(t, returned, "should've only removed reactions with emoji name") // check that the posts are updated - postList, err := ss.Post().Get(post.Id, false) + postList, err := ss.Post().Get(post.Id, false, false, false) require.Nil(t, err) assert.True(t, postList.Posts[post.Id].HasReactions, "post should still have reactions") - postList, err = ss.Post().Get(post2.Id, false) + postList, err = ss.Post().Get(post2.Id, false, false, false) require.Nil(t, err) assert.True(t, postList.Posts[post2.Id].HasReactions, "post should still have reactions") - postList, err = ss.Post().Get(post3.Id, false) + postList, err = ss.Post().Get(post3.Id, false, false, false) require.Nil(t, err) assert.False(t, postList.Posts[post3.Id].HasReactions, "post shouldn't have reactions any more") diff --git a/store/storetest/thread_store.go b/store/storetest/thread_store.go index daec4583f3..9179b96744 100644 --- a/store/storetest/thread_store.go +++ b/store/storetest/thread_store.go @@ -69,7 +69,7 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) { newPosts, errIdx, err3 := ss.Post().SaveMultiple([]*model.Post{&o2, &o3, &o4}) - olist, _ := ss.Post().Get(otmp.Id, true) + olist, _ := ss.Post().Get(otmp.Id, true, false, false) o1 := olist.Posts[olist.Order[0]] newPosts = append([]*model.Post{o1}, newPosts...) diff --git a/store/timerlayer/timerlayer.go b/store/timerlayer/timerlayer.go index 01e812c05c..d2773461f2 100644 --- a/store/timerlayer/timerlayer.go +++ b/store/timerlayer/timerlayer.go @@ -4444,10 +4444,10 @@ func (s *TimerLayerPostStore) Delete(postId string, time int64, deleteByID strin return err } -func (s *TimerLayerPostStore) Get(id string, skipFetchThreads bool) (*model.PostList, error) { +func (s *TimerLayerPostStore) Get(id string, skipFetchThreads bool, collapsedThreads bool, collapsedThreadsExtended bool) (*model.PostList, error) { start := timemodule.Now() - result, err := s.PostStore.Get(id, skipFetchThreads) + result, err := s.PostStore.Get(id, skipFetchThreads, collapsedThreads, collapsedThreadsExtended) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil { @@ -4476,10 +4476,10 @@ func (s *TimerLayerPostStore) GetDirectPostParentsForExportAfter(limit int, afte return result, err } -func (s *TimerLayerPostStore) GetEtag(channelId string, allowFromCache bool) string { +func (s *TimerLayerPostStore) GetEtag(channelId string, allowFromCache bool, collapsedThreads bool) string { start := timemodule.Now() - result := s.PostStore.GetEtag(channelId, allowFromCache) + result := s.PostStore.GetEtag(channelId, allowFromCache, collapsedThreads) elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second) if s.Root.Metrics != nil {