From 4d2ed469bf707b5972dd8f63fe2ee9d0128a92e1 Mon Sep 17 00:00:00 2001 From: Kyriakos Z <3829551+koox00@users.noreply.github.com> Date: Sat, 1 Apr 2023 10:32:40 +0300 Subject: [PATCH] MM-49564: Drafts upsert in the Store vs App layer (#22530) * MM-49564: Upsert in the Store vs App layer Refactor drafts so that Upserting a draft would happen in the DB and not in the app layer. * Fixes mocks * Fixes tests * Fixes translations * Fixes tests * Update tests * Fixes tests * Addresses review comments - renames Save => Upsert - removes Sleep from tests * Fixes flaky test --------- Co-authored-by: Mattermost Build --- model/draft.go | 9 +- model/draft_test.go | 13 -- server/channels/api4/drafts.go | 4 +- server/channels/app/app_iface.go | 2 - server/channels/app/draft.go | 69 +--------- server/channels/app/draft_test.go | 123 ++++++------------ .../app/opentracing/opentracing_layer.go | 44 ------- .../opentracinglayer/opentracinglayer.go | 24 +--- .../channels/store/retrylayer/retrylayer.go | 25 +--- server/channels/store/sqlstore/draft_store.go | 40 ++---- server/channels/store/store.go | 3 +- .../channels/store/storetest/draft_store.go | 99 +++++++------- .../store/storetest/mocks/DraftStore.go | 27 +--- .../channels/store/timerlayer/timerlayer.go | 22 +--- server/i18n/en.json | 4 - 15 files changed, 113 insertions(+), 395 deletions(-) diff --git a/model/draft.go b/model/draft.go index a9741e5727..73d1e69f98 100644 --- a/model/draft.go +++ b/model/draft.go @@ -81,9 +81,11 @@ func (o *Draft) GetProps() StringInterface { func (o *Draft) PreSave() { if o.CreateAt == 0 { o.CreateAt = GetMillis() + o.UpdateAt = o.CreateAt + } else { + o.UpdateAt = GetMillis() } - o.UpdateAt = o.CreateAt o.DeleteAt = 0 o.PreCommit() } @@ -100,8 +102,3 @@ func (o *Draft) PreCommit() { // There's a rare bug where the client sends up duplicate FileIds so protect against that o.FileIds = RemoveDuplicateStrings(o.FileIds) } - -func (o *Draft) PreUpdate() { - o.UpdateAt = GetMillis() - o.PreCommit() -} diff --git a/model/draft_test.go b/model/draft_test.go index 2e931c31dd..dec7e56a17 100644 --- a/model/draft_test.go +++ b/model/draft_test.go @@ -65,16 +65,3 @@ func TestDraftPreSave(t *testing.T) { assert.LessOrEqual(t, o.CreateAt, past) } - -func TestDraftPreUpdate(t *testing.T) { - o := Draft{Message: "test"} - o.PreUpdate() - - assert.NotEqual(t, 0, o.UpdateAt) - - past := GetMillis() - 1 - o = Draft{Message: "test", UpdateAt: past} - o.PreSave() - - assert.GreaterOrEqual(t, o.UpdateAt, past) -} diff --git a/server/channels/api4/drafts.go b/server/channels/api4/drafts.go index 2b11bf258a..bfd1112f94 100644 --- a/server/channels/api4/drafts.go +++ b/server/channels/api4/drafts.go @@ -62,7 +62,7 @@ func upsertDraft(c *Context, w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(dt); err != nil { - mlog.Warn("Error while writing response", mlog.Err(err)) + c.Logger.Warn("Error while writing response", mlog.Err(err)) } } @@ -94,7 +94,7 @@ func getDrafts(c *Context, w http.ResponseWriter, r *http.Request) { } if err := json.NewEncoder(w).Encode(drafts); err != nil { - mlog.Warn("Error while writing response", mlog.Err(err)) + c.Logger.Warn("Error while writing response", mlog.Err(err)) } } diff --git a/server/channels/app/app_iface.go b/server/channels/app/app_iface.go index cf749c3fc9..eb30bc0779 100644 --- a/server/channels/app/app_iface.go +++ b/server/channels/app/app_iface.go @@ -482,7 +482,6 @@ type AppIface interface { CreateChannelWithUser(c request.CTX, channel *model.Channel, userID string) (*model.Channel, *model.AppError) CreateCommand(cmd *model.Command) (*model.Command, *model.AppError) CreateCommandWebhook(commandID string, args *model.CommandArgs) (*model.CommandWebhook, *model.AppError) - CreateDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) CreateEmoji(c request.CTX, sessionUserId string, emoji *model.Emoji, multiPartImageData *multipart.Form) (*model.Emoji, *model.AppError) CreateGroup(group *model.Group) (*model.Group, *model.AppError) CreateGroupChannel(c request.CTX, userIDs []string, creatorId string) (*model.Channel, *model.AppError) @@ -1121,7 +1120,6 @@ type AppIface interface { UpdateChannelPrivacy(c request.CTX, oldChannel *model.Channel, user *model.User) (*model.Channel, *model.AppError) UpdateCommand(oldCmd, updatedCmd *model.Command) (*model.Command, *model.AppError) UpdateConfig(f func(*model.Config)) - UpdateDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) UpdateEphemeralPost(c request.CTX, userID string, post *model.Post) *model.Post UpdateExpiredDNDStatuses() ([]*model.Status, error) UpdateGroup(group *model.Group) (*model.Group, *model.AppError) diff --git a/server/channels/app/draft.go b/server/channels/app/draft.go index 3521adea71..a8ca1ef9d2 100644 --- a/server/channels/app/draft.go +++ b/server/channels/app/draft.go @@ -35,33 +35,6 @@ func (a *App) GetDraft(userID, channelID, rootID string) (*model.Draft, *model.A } func (a *App) UpsertDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) { - if !a.Config().FeatureFlags.GlobalDrafts || !*a.Config().ServiceSettings.AllowSyncedDrafts { - return nil, model.NewAppError("UpsertDraft", "app.draft.feature_disabled", nil, "", http.StatusNotImplemented) - } - - dt, dErr := a.Srv().Store().Draft().Get(draft.UserId, draft.ChannelId, draft.RootId, true) - var notFoundErr *store.ErrNotFound - if dErr != nil && !errors.As(dErr, ¬FoundErr) { - return nil, model.NewAppError("UpsertDraft", "app.select_error", nil, dErr.Error(), http.StatusInternalServerError) - } - - var err *model.AppError - if dt == nil { - dt, err = a.CreateDraft(c, draft, connectionID) - if err != nil { - return nil, err - } - } else { - dt, err = a.UpdateDraft(c, draft, connectionID) - if err != nil { - return nil, err - } - } - - return dt, nil -} - -func (a *App) CreateDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) { if !a.Config().FeatureFlags.GlobalDrafts || !*a.Config().ServiceSettings.AllowSyncedDrafts { return nil, model.NewAppError("CreateDraft", "app.draft.feature_disabled", nil, "", http.StatusNotImplemented) } @@ -83,7 +56,7 @@ func (a *App) CreateDraft(c *request.Context, draft *model.Draft, connectionID s return nil, model.NewAppError("CreateDraft", "app.user.get.app_error", nil, nErr.Error(), http.StatusInternalServerError) } - dt, nErr := a.Srv().Store().Draft().Save(draft) + dt, nErr := a.Srv().Store().Draft().Upsert(draft) if nErr != nil { return nil, model.NewAppError("CreateDraft", "app.draft.save.app_error", nil, nErr.Error(), http.StatusInternalServerError) } @@ -101,46 +74,6 @@ func (a *App) CreateDraft(c *request.Context, draft *model.Draft, connectionID s return dt, nil } -func (a *App) UpdateDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) { - if !a.Config().FeatureFlags.GlobalDrafts { - return nil, model.NewAppError("UpsertDraft", "app.draft.feature_disabled", nil, "", http.StatusNotImplemented) - } - - // Check that channel exists and has not been deleted - channel, errCh := a.Srv().Store().Channel().Get(draft.ChannelId, true) - if errCh != nil { - err := model.NewAppError("UpdateDraft", "api.context.invalid_param.app_error", map[string]interface{}{"Name": "draft.channel_id"}, errCh.Error(), http.StatusBadRequest) - return nil, err - } - - if channel.DeleteAt != 0 { - err := model.NewAppError("UpdateDraft", "api.draft.create_draft.can_not_draft_to_deleted.error", nil, "", http.StatusBadRequest) - return nil, err - } - - _, nErr := a.Srv().Store().User().Get(context.Background(), draft.UserId) - if nErr != nil { - return nil, model.NewAppError("UpdateDraft", "app.user.get.app_error", nil, nErr.Error(), http.StatusInternalServerError) - } - - dt, nErr := a.Srv().Store().Draft().Update(draft) - if nErr != nil { - return nil, model.NewAppError("UpdateDraft", "app.draft.update.app_error", nil, nErr.Error(), http.StatusInternalServerError) - } - - dt = a.prepareDraftWithFileInfos(draft.UserId, dt) - - message := model.NewWebSocketEvent(model.WebsocketEventDraftUpdated, "", draft.ChannelId, draft.UserId, nil, connectionID) - draftJSON, jsonErr := json.Marshal(dt) - if jsonErr != nil { - mlog.Warn("Failed to encode draft to JSON", mlog.Err(jsonErr)) - } - message.Add("draft", string(draftJSON)) - a.Publish(message) - - return dt, nil -} - func (a *App) GetDraftsForUser(userID, teamID string) ([]*model.Draft, *model.AppError) { if !a.Config().FeatureFlags.GlobalDrafts || !*a.Config().ServiceSettings.AllowSyncedDrafts { return nil, model.NewAppError("GetDraftsForUser", "app.draft.feature_disabled", nil, "", http.StatusNotImplemented) diff --git a/server/channels/app/draft_test.go b/server/channels/app/draft_test.go index 287ce82430..4b52efd8e9 100644 --- a/server/channels/app/draft_test.go +++ b/server/channels/app/draft_test.go @@ -81,34 +81,41 @@ func TestUpsertDraft(t *testing.T) { user := th.BasicUser channel := th.BasicChannel - draft1 := &model.Draft{ - CreateAt: 00001, - UpdateAt: 00001, + draft := &model.Draft{ UserId: user.Id, ChannelId: channel.Id, - Message: "draft1", + Message: "draft", } - draft2 := &model.Draft{ - CreateAt: 00001, - UpdateAt: 00002, - UserId: user.Id, - ChannelId: channel.Id, - Message: "draft2", - } - - _, createDraftErr := th.App.CreateDraft(th.Context, draft1, "") - assert.Nil(t, createDraftErr) - t.Run("upsert draft", func(t *testing.T) { - draftResp, err := th.App.UpsertDraft(th.Context, draft2, "") + _, err := th.App.UpsertDraft(th.Context, draft, "") assert.Nil(t, err) - assert.Equal(t, draft2.Message, draftResp.Message) - assert.Equal(t, draft2.ChannelId, draftResp.ChannelId) - assert.Equal(t, draft2.CreateAt, draftResp.CreateAt) + drafts, err := th.App.GetDraftsForUser(user.Id, th.BasicTeam.Id) + assert.Nil(t, err) + assert.Len(t, drafts, 1) + draft1 := drafts[0] - assert.NotEqual(t, draft1.UpdateAt, draftResp.UpdateAt) + assert.Equal(t, "draft", draft1.Message) + assert.Equal(t, channel.Id, draft1.ChannelId) + assert.Greater(t, draft1.CreateAt, int64(0)) + + draft = &model.Draft{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "updated draft", + } + _, err = th.App.UpsertDraft(th.Context, draft, "") + assert.Nil(t, err) + + drafts, err = th.App.GetDraftsForUser(user.Id, th.BasicTeam.Id) + assert.Nil(t, err) + assert.Len(t, drafts, 1) + draft2 := drafts[0] + + assert.Equal(t, "updated draft", draft2.Message) + assert.Equal(t, channel.Id, draft2.ChannelId) + assert.Equal(t, draft1.CreateAt, draft2.CreateAt) }) t.Run("upsert draft feature flag", func(t *testing.T) { @@ -123,7 +130,7 @@ func TestUpsertDraft(t *testing.T) { defer th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.GlobalDrafts = true }) defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) - _, err := th.App.UpsertDraft(th.Context, draft1, "") + _, err := th.App.UpsertDraft(th.Context, draft, "") assert.NotNil(t, err) }) } @@ -160,7 +167,7 @@ func TestCreateDraft(t *testing.T) { } t.Run("create draft", func(t *testing.T) { - draftResp, err := th.App.CreateDraft(th.Context, draft1, "") + draftResp, err := th.App.UpsertDraft(th.Context, draft1, "") assert.Nil(t, err) assert.Equal(t, draft1.Message, draftResp.Message) @@ -178,29 +185,13 @@ func TestCreateDraft(t *testing.T) { draftWithFiles := draft2 draftWithFiles.FileIds = []string{fileResp.Id} - draftResp, err := th.App.CreateDraft(th.Context, draftWithFiles, "") + draftResp, err := th.App.UpsertDraft(th.Context, draftWithFiles, "") assert.Nil(t, err) assert.Equal(t, draftWithFiles.Message, draftResp.Message) assert.Equal(t, draftWithFiles.ChannelId, draftResp.ChannelId) assert.ElementsMatch(t, draftWithFiles.FileIds, draftResp.FileIds) }) - - t.Run("create draft feature flag", func(t *testing.T) { - os.Setenv("MM_FEATUREFLAGS_GLOBALDRAFTS", "false") - defer os.Unsetenv("MM_FEATUREFLAGS_GLOBALDRAFTS") - os.Setenv("MM_SERVICESETTINGS_ALLOWSYNCEDDRAFTS", "false") - defer os.Unsetenv("MM_SERVICESETTINGS_ALLOWSYNCEDDRAFTS") - - th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.GlobalDrafts = false }) - th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = false }) - - defer th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.GlobalDrafts = true }) - defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) - - _, err := th.App.CreateDraft(th.Context, draft1, "") - assert.NotNil(t, err) - }) } func TestUpdateDraft(t *testing.T) { @@ -217,34 +208,14 @@ func TestUpdateDraft(t *testing.T) { channel := th.BasicChannel draft1 := &model.Draft{ - CreateAt: 00001, - UpdateAt: 00001, UserId: user.Id, ChannelId: channel.Id, Message: "draft1", } - draft2 := &model.Draft{ - CreateAt: 00001, - UpdateAt: 00002, - UserId: user.Id, - ChannelId: channel.Id, - Message: "draft2", - } - - _, createDraftErr := th.App.CreateDraft(th.Context, draft1, "") + _, createDraftErr := th.App.UpsertDraft(th.Context, draft1, "") assert.Nil(t, createDraftErr) - t.Run("update draft", func(t *testing.T) { - draftResp, err := th.App.UpdateDraft(th.Context, draft2, "") - assert.Nil(t, err) - - assert.Equal(t, draft2.Message, draftResp.Message) - assert.Equal(t, draft2.ChannelId, draftResp.ChannelId) - - assert.NotEqual(t, draft1.UpdateAt, draftResp.UpdateAt) - }) - t.Run("update draft with files", func(t *testing.T) { // upload file sent, readFileErr := testutils.ReadTestFile("test.png") @@ -256,29 +227,17 @@ func TestUpdateDraft(t *testing.T) { draftWithFiles := draft1 draftWithFiles.FileIds = []string{fileResp.Id} - draftResp, err := th.App.UpdateDraft(th.Context, draft1, "") + _, err := th.App.UpsertDraft(th.Context, draft1, "") assert.Nil(t, err) + drafts, err := th.App.GetDraftsForUser(user.Id, th.BasicTeam.Id) + assert.Nil(t, err) + + draftResp := drafts[0] assert.Equal(t, draftWithFiles.Message, draftResp.Message) assert.Equal(t, draftWithFiles.ChannelId, draftResp.ChannelId) assert.ElementsMatch(t, draftWithFiles.FileIds, draftResp.FileIds) }) - - t.Run("create draft feature flag", func(t *testing.T) { - os.Setenv("MM_FEATUREFLAGS_GLOBALDRAFTS", "false") - defer os.Unsetenv("MM_FEATUREFLAGS_GLOBALDRAFTS") - os.Setenv("MM_SERVICESETTINGS_ALLOWSYNCEDDRAFTS", "false") - defer os.Unsetenv("MM_SERVICESETTINGS_ALLOWSYNCEDDRAFTS") - - th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.GlobalDrafts = false }) - th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = false }) - - defer th.App.UpdateConfig(func(cfg *model.Config) { cfg.FeatureFlags.GlobalDrafts = true }) - defer th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.AllowSyncedDrafts = true }) - - _, err := th.App.UpdateDraft(th.Context, draft1, "") - assert.NotNil(t, err) - }) } func TestGetDraftsForUser(t *testing.T) { @@ -312,10 +271,10 @@ func TestGetDraftsForUser(t *testing.T) { Message: "draft2", } - _, createDraftErr1 := th.App.CreateDraft(th.Context, draft1, "") + _, createDraftErr1 := th.App.UpsertDraft(th.Context, draft1, "") assert.Nil(t, createDraftErr1) - _, createDraftErr2 := th.App.CreateDraft(th.Context, draft2, "") + _, createDraftErr2 := th.App.UpsertDraft(th.Context, draft2, "") assert.Nil(t, createDraftErr2) t.Run("get drafts", func(t *testing.T) { @@ -340,7 +299,7 @@ func TestGetDraftsForUser(t *testing.T) { draftWithFiles := draft1 draftWithFiles.FileIds = []string{fileResp.Id} - draftResp, updateDraftErr := th.App.UpdateDraft(th.Context, draft1, "") + draftResp, updateDraftErr := th.App.UpsertDraft(th.Context, draft1, "") assert.Nil(t, updateDraftErr) assert.Equal(t, draftWithFiles.Message, draftResp.Message) @@ -397,7 +356,7 @@ func TestDeleteDraft(t *testing.T) { Message: "draft1", } - _, createDraftErr := th.App.CreateDraft(th.Context, draft1, "") + _, createDraftErr := th.App.UpsertDraft(th.Context, draft1, "") assert.Nil(t, createDraftErr) t.Run("delete draft", func(t *testing.T) { @@ -411,7 +370,7 @@ func TestDeleteDraft(t *testing.T) { assert.Equal(t, draft1.ChannelId, draftResp.ChannelId) }) - t.Run("get drafts feature flag", func(t *testing.T) { + t.Run("delete drafts feature flag", func(t *testing.T) { os.Setenv("MM_FEATUREFLAGS_GLOBALDRAFTS", "false") defer os.Unsetenv("MM_FEATUREFLAGS_GLOBALDRAFTS") os.Setenv("MM_SERVICESETTINGS_ALLOWSYNCEDDRAFTS", "false") diff --git a/server/channels/app/opentracing/opentracing_layer.go b/server/channels/app/opentracing/opentracing_layer.go index 4a4cd72441..6e522bb971 100644 --- a/server/channels/app/opentracing/opentracing_layer.go +++ b/server/channels/app/opentracing/opentracing_layer.go @@ -2026,28 +2026,6 @@ func (a *OpenTracingAppLayer) CreateDefaultMemberships(c *request.Context, param return resultVar0 } -func (a *OpenTracingAppLayer) CreateDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) { - origCtx := a.ctx - span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CreateDraft") - - a.ctx = newCtx - a.app.Srv().Store().SetContext(newCtx) - defer func() { - a.app.Srv().Store().SetContext(origCtx) - a.ctx = origCtx - }() - - defer span.Finish() - resultVar0, resultVar1 := a.app.CreateDraft(c, draft, connectionID) - - if resultVar1 != nil { - span.LogFields(spanlog.Error(resultVar1)) - ext.Error.Set(span, true) - } - - return resultVar0, resultVar1 -} - func (a *OpenTracingAppLayer) CreateEmoji(c request.CTX, sessionUserId string, emoji *model.Emoji, multiPartImageData *multipart.Form) (*model.Emoji, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CreateEmoji") @@ -17381,28 +17359,6 @@ func (a *OpenTracingAppLayer) UpdateDNDStatusOfUsers() { a.app.UpdateDNDStatusOfUsers() } -func (a *OpenTracingAppLayer) UpdateDraft(c *request.Context, draft *model.Draft, connectionID string) (*model.Draft, *model.AppError) { - origCtx := a.ctx - span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateDraft") - - a.ctx = newCtx - a.app.Srv().Store().SetContext(newCtx) - defer func() { - a.app.Srv().Store().SetContext(origCtx) - a.ctx = origCtx - }() - - defer span.Finish() - resultVar0, resultVar1 := a.app.UpdateDraft(c, draft, connectionID) - - if resultVar1 != nil { - span.LogFields(spanlog.Error(resultVar1)) - ext.Error.Set(span, true) - } - - return resultVar0, resultVar1 -} - func (a *OpenTracingAppLayer) UpdateEphemeralPost(c request.CTX, userID string, post *model.Post) *model.Post { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.UpdateEphemeralPost") diff --git a/server/channels/store/opentracinglayer/opentracinglayer.go b/server/channels/store/opentracinglayer/opentracinglayer.go index 15d32a18e2..723b8a4775 100644 --- a/server/channels/store/opentracinglayer/opentracinglayer.go +++ b/server/channels/store/opentracinglayer/opentracinglayer.go @@ -3315,34 +3315,16 @@ func (s *OpenTracingLayerDraftStore) GetDraftsForUser(userID string, teamID stri return result, err } -func (s *OpenTracingLayerDraftStore) Save(d *model.Draft) (*model.Draft, error) { +func (s *OpenTracingLayerDraftStore) Upsert(d *model.Draft) (*model.Draft, error) { origCtx := s.Root.Store.Context() - span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.Save") + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.Upsert") s.Root.Store.SetContext(newCtx) defer func() { s.Root.Store.SetContext(origCtx) }() defer span.Finish() - result, err := s.DraftStore.Save(d) - if err != nil { - span.LogFields(spanlog.Error(err)) - ext.Error.Set(span, true) - } - - return result, err -} - -func (s *OpenTracingLayerDraftStore) Update(d *model.Draft) (*model.Draft, error) { - origCtx := s.Root.Store.Context() - span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "DraftStore.Update") - s.Root.Store.SetContext(newCtx) - defer func() { - s.Root.Store.SetContext(origCtx) - }() - - defer span.Finish() - result, err := s.DraftStore.Update(d) + result, err := s.DraftStore.Upsert(d) if err != nil { span.LogFields(spanlog.Error(err)) ext.Error.Set(span, true) diff --git a/server/channels/store/retrylayer/retrylayer.go b/server/channels/store/retrylayer/retrylayer.go index 07997b61ac..3603bd7f7e 100644 --- a/server/channels/store/retrylayer/retrylayer.go +++ b/server/channels/store/retrylayer/retrylayer.go @@ -3703,32 +3703,11 @@ func (s *RetryLayerDraftStore) GetDraftsForUser(userID string, teamID string) ([ } -func (s *RetryLayerDraftStore) Save(d *model.Draft) (*model.Draft, error) { +func (s *RetryLayerDraftStore) Upsert(d *model.Draft) (*model.Draft, error) { tries := 0 for { - result, err := s.DraftStore.Save(d) - if err == nil { - return result, nil - } - if !isRepeatableError(err) { - return result, err - } - tries++ - if tries >= 3 { - err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") - return result, err - } - timepkg.Sleep(100 * timepkg.Millisecond) - } - -} - -func (s *RetryLayerDraftStore) Update(d *model.Draft) (*model.Draft, error) { - - tries := 0 - for { - result, err := s.DraftStore.Update(d) + result, err := s.DraftStore.Upsert(d) if err == nil { return result, nil } diff --git a/server/channels/store/sqlstore/draft_store.go b/server/channels/store/sqlstore/draft_store.go index b5b7dee743..7d4137a8e5 100644 --- a/server/channels/store/sqlstore/draft_store.go +++ b/server/channels/store/sqlstore/draft_store.go @@ -88,7 +88,7 @@ func (s *SqlDraftStore) Get(userId, channelId, rootId string, includeDeleted boo return &dt, nil } -func (s *SqlDraftStore) Save(draft *model.Draft) (*model.Draft, error) { +func (s *SqlDraftStore) Upsert(draft *model.Draft) (*model.Draft, error) { draft.PreSave() maxDraftSize := s.GetMaxDraftSize() if err := draft.IsValid(maxDraftSize); err != nil { @@ -96,6 +96,13 @@ func (s *SqlDraftStore) Save(draft *model.Draft) (*model.Draft, error) { } builder := s.getQueryBuilder().Insert("Drafts").Columns(draftSliceColumns()...).Values(draftToSlice(draft)...) + + if s.DriverName() == model.DatabaseDriverMysql { + builder = builder.SuffixExpr(sq.Expr("ON DUPLICATE KEY UPDATE UpdateAt = ?, Message = ?, Props = ?, FileIds = ?, Priority = ?, DeleteAt = ?", draft.UpdateAt, draft.Message, draft.Props, draft.FileIds, draft.Priority, 0)) + } else { + builder = builder.SuffixExpr(sq.Expr("ON CONFLICT (UserId, ChannelId, RootId) DO UPDATE SET UpdateAt = ?, Message = ?, Props = ?, FileIds = ?, Priority = ?, DeleteAt = ?", draft.UpdateAt, draft.Message, draft.Props, draft.FileIds, draft.Priority, 0)) + } + query, args, err := builder.ToSql() if err != nil { @@ -103,36 +110,7 @@ func (s *SqlDraftStore) Save(draft *model.Draft) (*model.Draft, error) { } if _, err = s.GetMasterX().Exec(query, args...); err != nil { - return nil, errors.Wrap(err, "failed to save Draft") - } - - return draft, nil -} - -func (s *SqlDraftStore) Update(draft *model.Draft) (*model.Draft, error) { - draft.PreUpdate() - - maxDraftSize := s.GetMaxDraftSize() - if err := draft.IsValid(maxDraftSize); err != nil { - return nil, err - } - - query := s.getQueryBuilder(). - Update("Drafts"). - Set("UpdateAt", draft.UpdateAt). - Set("Message", draft.Message). - Set("Props", draft.Props). - Set("FileIds", draft.FileIds). - Set("Priority", draft.Priority). - Set("DeleteAt", 0). - Where(sq.Eq{ - "UserId": draft.UserId, - "ChannelId": draft.ChannelId, - "RootId": draft.RootId, - }) - - if _, err := s.GetMasterX().ExecBuilder(query); err != nil { - return nil, errors.Wrapf(err, "failed to update Draft with channelid=%s", draft.ChannelId) + return nil, errors.Wrap(err, "failed to upsert Draft") } return draft, nil diff --git a/server/channels/store/store.go b/server/channels/store/store.go index dec4fa0f89..e52c2037ab 100644 --- a/server/channels/store/store.go +++ b/server/channels/store/store.go @@ -985,11 +985,10 @@ type PostPriorityStore interface { } type DraftStore interface { - Save(d *model.Draft) (*model.Draft, error) + Upsert(d *model.Draft) (*model.Draft, error) Get(userID, channelID, rootID string, includeDeleted bool) (*model.Draft, error) Delete(userID, channelID, rootID string) error GetDraftsForUser(userID, teamID string) ([]*model.Draft, error) - Update(d *model.Draft) (*model.Draft, error) } type PostAcknowledgementStore interface { diff --git a/server/channels/store/storetest/draft_store.go b/server/channels/store/storetest/draft_store.go index 377f36fa60..45eb33cae9 100644 --- a/server/channels/store/storetest/draft_store.go +++ b/server/channels/store/storetest/draft_store.go @@ -68,17 +68,21 @@ func testSaveDraft(t *testing.T, ss store.Store) { } t.Run("save drafts", func(t *testing.T) { - draftResp, err := ss.Draft().Save(draft1) + draftResp, err := ss.Draft().Upsert(draft1) assert.NoError(t, err) assert.Equal(t, draft1.Message, draftResp.Message) assert.Equal(t, draft1.ChannelId, draftResp.ChannelId) - draftResp, err = ss.Draft().Save(draft2) + draftResp, err = ss.Draft().Upsert(draft2) assert.NoError(t, err) assert.Equal(t, draft2.Message, draftResp.Message) assert.Equal(t, draft2.ChannelId, draftResp.ChannelId) + + drafts, err := ss.Draft().GetDraftsForUser(user.Id, "") + assert.NoError(t, err) + assert.Len(t, drafts, 2) }) } @@ -90,56 +94,52 @@ func testUpdateDraft(t *testing.T, ss store.Store) { channel := &model.Channel{ Id: model.NewId(), } - channel2 := &model.Channel{ - Id: model.NewId(), - } - member1 := &model.ChannelMember{ + member := &model.ChannelMember{ ChannelId: channel.Id, UserId: user.Id, NotifyProps: model.GetDefaultChannelNotifyProps(), } - member2 := &model.ChannelMember{ - ChannelId: channel2.Id, - UserId: user.Id, - NotifyProps: model.GetDefaultChannelNotifyProps(), - } - - _, err := ss.Channel().SaveMember(member1) + _, err := ss.Channel().SaveMember(member) require.NoError(t, err) - _, err = ss.Channel().SaveMember(member2) - require.NoError(t, err) - - draft1 := &model.Draft{ - CreateAt: 00001, - UpdateAt: 00001, - UserId: user.Id, - ChannelId: channel.Id, - Message: "draft1", - } - - draft2 := &model.Draft{ - CreateAt: 00005, - UpdateAt: 00005, - UserId: user.Id, - ChannelId: channel2.Id, - Message: "draft2", - } - t.Run("update drafts", func(t *testing.T) { - draftResp, err := ss.Draft().Update(draft1) + draft := &model.Draft{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "draft", + } + _, err := ss.Draft().Upsert(draft) assert.NoError(t, err) - assert.Equal(t, draft1.Message, draftResp.Message) - assert.Equal(t, draft1.ChannelId, draftResp.ChannelId) + drafts, err := ss.Draft().GetDraftsForUser(user.Id, "") + assert.NoError(t, err) + assert.Len(t, drafts, 1) + draft1 := drafts[0] - draftResp, err = ss.Draft().Update(draft2) + assert.Greater(t, draft1.CreateAt, int64(0)) + assert.Equal(t, draft1.UpdateAt, draft1.CreateAt) + assert.Equal(t, channel.Id, draft1.ChannelId) + assert.Equal(t, "draft", draft1.Message) + + updatedDraft := &model.Draft{ + UserId: user.Id, + ChannelId: channel.Id, + Message: "updatedDraft", + } + _, err = ss.Draft().Upsert(updatedDraft) assert.NoError(t, err) - assert.Equal(t, draft2.Message, draftResp.Message) - assert.Equal(t, draft2.ChannelId, draftResp.ChannelId) + drafts, err = ss.Draft().GetDraftsForUser(user.Id, "") + assert.NoError(t, err) + assert.Len(t, drafts, 1) + draft2 := drafts[0] + + assert.Greater(t, draft2.CreateAt, int64(0)) + assert.Equal(t, "updatedDraft", draft2.Message) + assert.Equal(t, channel.Id, draft2.ChannelId) + assert.Equal(t, draft1.CreateAt, draft2.CreateAt) }) } @@ -189,10 +189,10 @@ func testDeleteDraft(t *testing.T, ss store.Store) { Message: "draft2", } - _, err = ss.Draft().Save(draft1) + _, err = ss.Draft().Upsert(draft1) require.NoError(t, err) - _, err = ss.Draft().Save(draft2) + _, err = ss.Draft().Upsert(draft2) require.NoError(t, err) t.Run("delete drafts", func(t *testing.T) { @@ -258,10 +258,10 @@ func testGetDraft(t *testing.T, ss store.Store) { Message: "draft2", } - _, err = ss.Draft().Save(draft1) + _, err = ss.Draft().Upsert(draft1) require.NoError(t, err) - _, err = ss.Draft().Save(draft2) + _, err = ss.Draft().Upsert(draft2) require.NoError(t, err) t.Run("get drafts", func(t *testing.T) { @@ -326,35 +326,28 @@ func testGetDraftsForUser(t *testing.T, ss store.Store) { require.NoError(t, err) draft1 := &model.Draft{ - CreateAt: 00001, - UpdateAt: 00001, UserId: user.Id, ChannelId: channel.Id, Message: "draft1", } draft2 := &model.Draft{ - CreateAt: 00005, - UpdateAt: 00005, UserId: user.Id, ChannelId: channel2.Id, Message: "draft2", } - _, err = ss.Draft().Save(draft1) + _, err = ss.Draft().Upsert(draft1) require.NoError(t, err) - _, err = ss.Draft().Save(draft2) + _, err = ss.Draft().Upsert(draft2) require.NoError(t, err) t.Run("get drafts", func(t *testing.T) { draftResp, err := ss.Draft().GetDraftsForUser(user.Id, "") assert.NoError(t, err) + assert.Len(t, draftResp, 2) - assert.Equal(t, draft2.Message, draftResp[0].Message) - assert.Equal(t, draft2.ChannelId, draftResp[0].ChannelId) - - assert.Equal(t, draft1.Message, draftResp[1].Message) - assert.Equal(t, draft1.ChannelId, draftResp[1].ChannelId) + assert.ElementsMatch(t, []*model.Draft{draft1, draft2}, draftResp) }) } diff --git a/server/channels/store/storetest/mocks/DraftStore.go b/server/channels/store/storetest/mocks/DraftStore.go index 1eb7d5f17b..924bb49883 100644 --- a/server/channels/store/storetest/mocks/DraftStore.go +++ b/server/channels/store/storetest/mocks/DraftStore.go @@ -74,31 +74,8 @@ func (_m *DraftStore) GetDraftsForUser(userID string, teamID string) ([]*model.D return r0, r1 } -// Save provides a mock function with given fields: d -func (_m *DraftStore) Save(d *model.Draft) (*model.Draft, error) { - ret := _m.Called(d) - - var r0 *model.Draft - if rf, ok := ret.Get(0).(func(*model.Draft) *model.Draft); ok { - r0 = rf(d) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*model.Draft) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(*model.Draft) error); ok { - r1 = rf(d) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Update provides a mock function with given fields: d -func (_m *DraftStore) Update(d *model.Draft) (*model.Draft, error) { +// Upsert provides a mock function with given fields: d +func (_m *DraftStore) Upsert(d *model.Draft) (*model.Draft, error) { ret := _m.Called(d) var r0 *model.Draft diff --git a/server/channels/store/timerlayer/timerlayer.go b/server/channels/store/timerlayer/timerlayer.go index 8199138ac2..2c156ccbe5 100644 --- a/server/channels/store/timerlayer/timerlayer.go +++ b/server/channels/store/timerlayer/timerlayer.go @@ -3038,10 +3038,10 @@ func (s *TimerLayerDraftStore) GetDraftsForUser(userID string, teamID string) ([ return result, err } -func (s *TimerLayerDraftStore) Save(d *model.Draft) (*model.Draft, error) { +func (s *TimerLayerDraftStore) Upsert(d *model.Draft) (*model.Draft, error) { start := time.Now() - result, err := s.DraftStore.Save(d) + result, err := s.DraftStore.Upsert(d) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil { @@ -3049,23 +3049,7 @@ func (s *TimerLayerDraftStore) Save(d *model.Draft) (*model.Draft, error) { if err == nil { success = "true" } - s.Root.Metrics.ObserveStoreMethodDuration("DraftStore.Save", success, elapsed) - } - return result, err -} - -func (s *TimerLayerDraftStore) Update(d *model.Draft) (*model.Draft, error) { - start := time.Now() - - result, err := s.DraftStore.Update(d) - - elapsed := float64(time.Since(start)) / float64(time.Second) - if s.Root.Metrics != nil { - success := "false" - if err == nil { - success = "true" - } - s.Root.Metrics.ObserveStoreMethodDuration("DraftStore.Update", success, elapsed) + s.Root.Metrics.ObserveStoreMethodDuration("DraftStore.Upsert", success, elapsed) } return result, err } diff --git a/server/i18n/en.json b/server/i18n/en.json index 3e72fa3edc..506628e0bb 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -5027,10 +5027,6 @@ "id": "app.draft.save.app_error", "translation": "Unable to save the Draft." }, - { - "id": "app.draft.update.app_error", - "translation": "Unable to update the Draft." - }, { "id": "app.email.no_rate_limiter.app_error", "translation": "Rate limiter is not set up."