Migrate "Post.GetPostsByIds" to Sync by default (#11034)

This commit is contained in:
GianOrtiz
2019-06-04 18:25:16 -03:00
committed by Jesse Hallam
parent a57e042dab
commit b26e5d444e
6 changed files with 39 additions and 30 deletions

View File

@@ -260,10 +260,9 @@ func AssertFileIdsInPost(files []*model.FileInfo, th *TestHelper, t *testing.T)
postId := files[0].PostId
assert.NotNil(t, postId)
if result := <-th.App.Srv.Store.Post().GetPostsByIds([]string{postId}); result.Err != nil {
t.Fatal(result.Err.Error())
if posts, err := th.App.Srv.Store.Post().GetPostsByIds([]string{postId}); err != nil {
t.Fatal(err.Error())
} else {
posts := result.Data.([]*model.Post)
assert.Equal(t, len(posts), 1)
for _, file := range files {
assert.Contains(t, posts[0].FileIds, file.Id)

View File

@@ -883,11 +883,11 @@ func (a *App) SearchPostsInTeamForUser(terms string, userId string, teamId strin
// Get the posts
postList := model.NewPostList()
if len(postIds) > 0 {
presult := <-a.Srv.Store.Post().GetPostsByIds(postIds)
if presult.Err != nil {
return nil, presult.Err
posts, err := a.Srv.Store.Post().GetPostsByIds(postIds)
if err != nil {
return nil, err
}
for _, p := range presult.Data.([]*model.Post) {
for _, p := range posts {
if p.DeleteAt == 0 {
postList.AddPost(p)
postList.AddOrder(p.Id)

View File

@@ -1105,22 +1105,19 @@ func (s *SqlPostStore) GetPostsCreatedAt(channelId string, time int64) ([]*model
return posts, nil
}
func (s *SqlPostStore) GetPostsByIds(postIds []string) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
keys, params := MapStringsToQueryParams(postIds, "Post")
func (s *SqlPostStore) GetPostsByIds(postIds []string) ([]*model.Post, *model.AppError) {
keys, params := MapStringsToQueryParams(postIds, "Post")
query := `SELECT * FROM Posts WHERE Id IN ` + keys + ` ORDER BY CreateAt DESC`
query := `SELECT * FROM Posts WHERE Id IN ` + keys + ` ORDER BY CreateAt DESC`
var posts []*model.Post
_, err := s.GetReplica().Select(&posts, query, params)
var posts []*model.Post
_, err := s.GetReplica().Select(&posts, query, params)
if err != nil {
mlog.Error(fmt.Sprint(err))
result.Err = model.NewAppError("SqlPostStore.GetPostsByIds", "store.sql_post.get_posts_by_ids.app_error", nil, "", http.StatusInternalServerError)
} else {
result.Data = posts
}
})
if err != nil {
mlog.Error(fmt.Sprint(err))
return nil, model.NewAppError("SqlPostStore.GetPostsByIds", "store.sql_post.get_posts_by_ids.app_error", nil, "", http.StatusInternalServerError)
}
return posts, nil
}
func (s *SqlPostStore) GetPostsBatchForIndexing(startTime int64, endTime int64, limit int) store.StoreChannel {

View File

@@ -233,7 +233,7 @@ type PostStore interface {
InvalidateLastPostTimeCache(channelId string)
GetPostsCreatedAt(channelId string, time int64) ([]*model.Post, *model.AppError)
Overwrite(post *model.Post) (*model.Post, *model.AppError)
GetPostsByIds(postIds []string) StoreChannel
GetPostsByIds(postIds []string) ([]*model.Post, *model.AppError)
GetPostsBatchForIndexing(startTime int64, endTime int64, limit int) StoreChannel
PermanentDeleteBatch(endTime int64, limit int64) StoreChannel
GetOldest() StoreChannel

View File

@@ -350,19 +350,28 @@ func (_m *PostStore) GetPostsBefore(channelId string, postId string, numPosts in
}
// GetPostsByIds provides a mock function with given fields: postIds
func (_m *PostStore) GetPostsByIds(postIds []string) store.StoreChannel {
func (_m *PostStore) GetPostsByIds(postIds []string) ([]*model.Post, *model.AppError) {
ret := _m.Called(postIds)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func([]string) store.StoreChannel); ok {
var r0 []*model.Post
if rf, ok := ret.Get(0).(func([]string) []*model.Post); ok {
r0 = rf(postIds)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).([]*model.Post)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func([]string) *model.AppError); ok {
r1 = rf(postIds)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// GetPostsCreatedAt provides a mock function with given fields: channelId, time

View File

@@ -1935,16 +1935,20 @@ func testPostStoreGetPostsByIds(t *testing.T, ss store.Store) {
ro3.Id,
}
if ro4 := store.Must(ss.Post().GetPostsByIds(postIds)).([]*model.Post); len(ro4) != 3 {
t.Fatalf("Expected 3 posts in results. Got %v", len(ro4))
if posts, err := ss.Post().GetPostsByIds(postIds); err != nil {
t.Fatal(err)
} else if len(posts) != 3 {
t.Fatalf("Expected 3 posts in results. Got %v", len(posts))
}
if err := ss.Post().Delete(ro1.Id, model.GetMillis(), ""); err != nil {
t.Fatal(err)
}
if ro5 := store.Must(ss.Post().GetPostsByIds(postIds)).([]*model.Post); len(ro5) != 3 {
t.Fatalf("Expected 3 posts in results. Got %v", len(ro5))
if posts, err := ss.Post().GetPostsByIds(postIds); err != nil {
t.Fatal(err)
} else if len(posts) != 3 {
t.Fatalf("Expected 3 posts in results. Got %v", len(posts))
}
}