diff --git a/app/file.go b/app/file.go index 764d117a42..a2157eb53b 100644 --- a/app/file.go +++ b/app/file.go @@ -275,7 +275,8 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo { if newPost := result.Posts[post.Id]; len(newPost.Filenames) != len(post.Filenames) { // Another thread has already created FileInfos for this post, so just return those - fileInfos, err := a.Srv.Store.FileInfo().GetForPost(post.Id, true, false) + var fileInfos []*model.FileInfo + fileInfos, err = a.Srv.Store.FileInfo().GetForPost(post.Id, true, false) if err != nil { mlog.Error(fmt.Sprintf("Unable to get FileInfos for migrated post, err=%v", err), mlog.String("post_id", post.Id)) return []*model.FileInfo{} @@ -290,7 +291,7 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo { savedInfos := make([]*model.FileInfo, 0, len(infos)) fileIds := make([]string, 0, len(filenames)) for _, info := range infos { - if _, err := a.Srv.Store.FileInfo().Save(info); err != nil { + if _, err = a.Srv.Store.FileInfo().Save(info); err != nil { mlog.Error( fmt.Sprintf("Unable to save file info when migrating post to use FileInfos, err=%v", err), mlog.String("post_id", post.Id), @@ -312,8 +313,8 @@ func (a *App) MigrateFilenamesToFileInfos(post *model.Post) []*model.FileInfo { newPost.FileIds = fileIds // Update Posts to clear Filenames and set FileIds - if result := <-a.Srv.Store.Post().Update(newPost, post); result.Err != nil { - mlog.Error(fmt.Sprintf("Unable to save migrated post when migrating to use FileInfos, new_file_ids=%v, old_filenames=%v, err=%v", newPost.FileIds, post.Filenames, result.Err), mlog.String("post_id", post.Id)) + if _, err = a.Srv.Store.Post().Update(newPost, post); err != nil { + mlog.Error(fmt.Sprintf("Unable to save migrated post when migrating to use FileInfos, new_file_ids=%v, old_filenames=%v, err=%v", newPost.FileIds, post.Filenames, err), mlog.String("post_id", post.Id)) return []*model.FileInfo{} } return savedInfos diff --git a/app/post.go b/app/post.go index 88d949f401..199ce06522 100644 --- a/app/post.go +++ b/app/post.go @@ -492,8 +492,7 @@ func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model } if channel.DeleteAt != 0 { - err := model.NewAppError("UpdatePost", "api.post.update_post.can_not_update_post_in_deleted.error", nil, "", http.StatusBadRequest) - return nil, err + return nil, model.NewAppError("UpdatePost", "api.post.update_post.can_not_update_post_in_deleted.error", nil, "", http.StatusBadRequest) } newPost := &model.Post{} @@ -517,7 +516,7 @@ func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model newPost.EditAt = model.GetMillis() } - if err := a.FillInPostProps(post, nil); err != nil { + if err = a.FillInPostProps(post, nil); err != nil { return nil, err } @@ -533,11 +532,10 @@ func (a *App) UpdatePost(post *model.Post, safeUpdate bool) (*model.Post, *model } } - result := <-a.Srv.Store.Post().Update(newPost, oldPost) - if result.Err != nil { - return nil, result.Err + rpost, err := a.Srv.Store.Post().Update(newPost, oldPost) + if err != nil { + return nil, err } - rpost := result.Data.(*model.Post) if pluginsEnvironment := a.GetPluginsEnvironment(); pluginsEnvironment != nil { a.Srv.Go(func() { diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index bc02ea7996..d7d213dceb 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -129,39 +129,37 @@ func (s *SqlPostStore) Save(post *model.Post) store.StoreChannel { }) } -func (s *SqlPostStore) Update(newPost *model.Post, oldPost *model.Post) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - newPost.UpdateAt = model.GetMillis() - newPost.PreCommit() +func (s *SqlPostStore) Update(newPost *model.Post, oldPost *model.Post) (*model.Post, *model.AppError) { + newPost.UpdateAt = model.GetMillis() + newPost.PreCommit() - oldPost.DeleteAt = newPost.UpdateAt - oldPost.UpdateAt = newPost.UpdateAt - oldPost.OriginalId = oldPost.Id - oldPost.Id = model.NewId() - oldPost.PreCommit() + oldPost.DeleteAt = newPost.UpdateAt + oldPost.UpdateAt = newPost.UpdateAt + oldPost.OriginalId = oldPost.Id + oldPost.Id = model.NewId() + oldPost.PreCommit() - maxPostSize := s.GetMaxPostSize() + maxPostSize := s.GetMaxPostSize() - if result.Err = newPost.IsValid(maxPostSize); result.Err != nil { - return - } + if err := newPost.IsValid(maxPostSize); err != nil { + return nil, err + } - if _, err := s.GetMaster().Update(newPost); err != nil { - result.Err = model.NewAppError("SqlPostStore.Update", "store.sql_post.update.app_error", nil, "id="+newPost.Id+", "+err.Error(), http.StatusInternalServerError) - } else { - time := model.GetMillis() - s.GetMaster().Exec("UPDATE Channels SET LastPostAt = :LastPostAt WHERE Id = :ChannelId AND LastPostAt < :LastPostAt", map[string]interface{}{"LastPostAt": time, "ChannelId": newPost.ChannelId}) + if _, err := s.GetMaster().Update(newPost); err != nil { + return nil, model.NewAppError("SqlPostStore.Update", "store.sql_post.update.app_error", nil, "id="+newPost.Id+", "+err.Error(), http.StatusInternalServerError) + } - if len(newPost.RootId) > 0 { - s.GetMaster().Exec("UPDATE Posts SET UpdateAt = :UpdateAt WHERE Id = :RootId AND UpdateAt < :UpdateAt", map[string]interface{}{"UpdateAt": time, "RootId": newPost.RootId}) - } + time := model.GetMillis() + s.GetMaster().Exec("UPDATE Channels SET LastPostAt = :LastPostAt WHERE Id = :ChannelId AND LastPostAt < :LastPostAt", map[string]interface{}{"LastPostAt": time, "ChannelId": newPost.ChannelId}) - // mark the old post as deleted - s.GetMaster().Insert(oldPost) + if len(newPost.RootId) > 0 { + s.GetMaster().Exec("UPDATE Posts SET UpdateAt = :UpdateAt WHERE Id = :RootId AND UpdateAt < :UpdateAt", map[string]interface{}{"UpdateAt": time, "RootId": newPost.RootId}) + } - result.Data = newPost - } - }) + // mark the old post as deleted + s.GetMaster().Insert(oldPost) + + return newPost, nil } func (s *SqlPostStore) Overwrite(post *model.Post) (*model.Post, *model.AppError) { diff --git a/store/store.go b/store/store.go index c0c190b0a0..97917fe87c 100644 --- a/store/store.go +++ b/store/store.go @@ -211,7 +211,7 @@ type ChannelMemberHistoryStore interface { type PostStore interface { Save(post *model.Post) StoreChannel - Update(newPost *model.Post, oldPost *model.Post) StoreChannel + Update(newPost *model.Post, oldPost *model.Post) (*model.Post, *model.AppError) Get(id string) (*model.PostList, *model.AppError) GetSingle(id string) StoreChannel Delete(postId string, time int64, deleteByID string) *model.AppError diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index b6dd1324bd..d464188dd0 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -506,17 +506,26 @@ func (_m *PostStore) Search(teamId string, userId string, params *model.SearchPa } // Update provides a mock function with given fields: newPost, oldPost -func (_m *PostStore) Update(newPost *model.Post, oldPost *model.Post) store.StoreChannel { +func (_m *PostStore) Update(newPost *model.Post, oldPost *model.Post) (*model.Post, *model.AppError) { ret := _m.Called(newPost, oldPost) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(*model.Post, *model.Post) store.StoreChannel); ok { + var r0 *model.Post + if rf, ok := ret.Get(0).(func(*model.Post, *model.Post) *model.Post); ok { r0 = rf(newPost, oldPost) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.Post) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Post, *model.Post) *model.AppError); ok { + r1 = rf(newPost, oldPost) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index c989e2bd45..b25822629b 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -244,8 +244,8 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { o1a := &model.Post{} *o1a = *ro1 o1a.Message = ro1.Message + "BBBBBBBBBB" - if result := <-ss.Post().Update(o1a, ro1); result.Err != nil { - t.Fatal(result.Err) + if _, err = ss.Post().Update(o1a, ro1); err != nil { + t.Fatal(err) } r1, err = ss.Post().Get(o1.Id) @@ -261,8 +261,8 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { o2a := &model.Post{} *o2a = *ro2 o2a.Message = ro2.Message + "DDDDDDD" - if result := <-ss.Post().Update(o2a, ro2); result.Err != nil { - t.Fatal(result.Err) + if _, err = ss.Post().Update(o2a, ro2); err != nil { + t.Fatal(err) } r2, err = ss.Post().Get(o1.Id) @@ -278,8 +278,8 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { o3a := &model.Post{} *o3a = *ro3 o3a.Message = ro3.Message + "WWWWWWW" - if result := <-ss.Post().Update(o3a, ro3); result.Err != nil { - t.Fatal(result.Err) + if _, err = ss.Post().Update(o3a, ro3); err != nil { + t.Fatal(err) } r3, err = ss.Post().Get(o3.Id) @@ -309,8 +309,8 @@ func testPostStoreUpdate(t *testing.T, ss store.Store) { *o4a = *ro4 o4a.Filenames = []string{} o4a.FileIds = []string{model.NewId()} - if result := <-ss.Post().Update(o4a, ro4); result.Err != nil { - t.Fatal(result.Err) + if _, err = ss.Post().Update(o4a, ro4); err != nil { + t.Fatal(err) } r4, err = ss.Post().Get(o4.Id) @@ -2288,8 +2288,8 @@ func testPostStoreGetDirectPostParentsForExportAfterDeleted(t *testing.T, ss sto *o1a = *p1 o1a.DeleteAt = 1 o1a.Message = p1.Message + "BBBBBBBBBB" - if result := <-ss.Post().Update(o1a, p1); result.Err != nil { - t.Fatal(result.Err) + if _, err := ss.Post().Update(o1a, p1); err != nil { + t.Fatal(err) } r1 := <-ss.Post().GetDirectPostParentsForExportAfter(10000, strings.Repeat("0", 26))