From df08325ef6a33bfac41cda05809ea04c29ffc708 Mon Sep 17 00:00:00 2001 From: Agniva De Sarker Date: Fri, 11 Oct 2024 09:11:38 +0530 Subject: [PATCH] store: use ToSQL (#28681) - Improve some code to use ToSQL - Fix a bug in SqlReactionStore.GetForPostSince where OrderBy wasn't added properly. Added a test. - Improve a method to return a slice rather than pass a pointer to slice. ```release-note NONE ``` --- .../channels/store/sqlstore/reaction_store.go | 52 ++++------- server/channels/store/sqlstore/user_store.go | 91 ++++--------------- .../store/storetest/reaction_store.go | 8 +- 3 files changed, 41 insertions(+), 110 deletions(-) diff --git a/server/channels/store/sqlstore/reaction_store.go b/server/channels/store/sqlstore/reaction_store.go index f3437c704b..d559448861 100644 --- a/server/channels/store/sqlstore/reaction_store.go +++ b/server/channels/store/sqlstore/reaction_store.go @@ -38,10 +38,8 @@ func (s *SqlReactionStore) Save(reaction *model.Reaction) (re *model.Reaction, e if reaction.ChannelId == "" { // get channelId, if not already populated var channelIds []string - var args []interface{} query := "SELECT ChannelId from Posts where Id = ?" - args = append(args, reaction.PostId) - err = transaction.Select(&channelIds, query, args...) + err = transaction.Select(&channelIds, query, reaction.PostId) if err != nil { return nil, errors.Wrap(err, "failed while getting channelId from Posts") } @@ -89,21 +87,16 @@ func (s *SqlReactionStore) Delete(reaction *model.Reaction) (re *model.Reaction, // GetForPost returns all reactions associated with `postId` that are not deleted. func (s *SqlReactionStore) GetForPost(postId string, allowFromCache bool) ([]*model.Reaction, error) { - queryString, args, err := s.getQueryBuilder(). + builder := s.getQueryBuilder(). Select("UserId", "PostId", "EmojiName", "CreateAt", "COALESCE(UpdateAt, CreateAt) As UpdateAt", "COALESCE(DeleteAt, 0) As DeleteAt", "RemoteId", "ChannelId"). From("Reactions"). Where(sq.Eq{"PostId": postId}). Where(sq.Eq{"COALESCE(DeleteAt, 0)": 0}). - OrderBy("CreateAt"). - ToSql() - - if err != nil { - return nil, errors.Wrap(err, "reactions_getforpost_tosql") - } + OrderBy("CreateAt") var reactions []*model.Reaction - if err := s.GetReplicaX().Select(&reactions, queryString, args...); err != nil { + if err := s.GetReplicaX().SelectBuilder(&reactions, builder); err != nil { return nil, errors.Wrapf(err, "failed to get Reactions with postId=%s", postId) } return reactions, nil @@ -135,7 +128,8 @@ func (s *SqlReactionStore) GetForPostSince(postId string, since int64, excludeRe "COALESCE(DeleteAt, 0) As DeleteAt", "RemoteId"). From("Reactions"). Where(sq.Eq{"PostId": postId}). - Where(sq.Gt{"UpdateAt": since}) + Where(sq.Gt{"UpdateAt": since}). + OrderBy("CreateAt") if excludeRemoteId != "" { query = query.Where(sq.NotEq{"COALESCE(RemoteId, '')": excludeRemoteId}) @@ -145,15 +139,8 @@ func (s *SqlReactionStore) GetForPostSince(postId string, since int64, excludeRe query = query.Where(sq.Eq{"COALESCE(DeleteAt, 0)": 0}) } - query.OrderBy("CreateAt") - - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "reactions_getforpostsince_tosql") - } - var reactions []*model.Reaction - if err := s.GetReplicaX().Select(&reactions, queryString, args...); err != nil { + if err := s.GetReplicaX().SelectBuilder(&reactions, query); err != nil { return nil, errors.Wrapf(err, "failed to find reactions") } return reactions, nil @@ -210,13 +197,8 @@ func (s *SqlReactionStore) GetSingle(userID, postID, remoteID, emojiName string) Where(sq.Eq{"COALESCE(RemoteId, '')": remoteID}). Where(sq.Eq{"EmojiName": emojiName}) - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "reactions_getsingle_tosql") - } - var reactions []*model.Reaction - if err := s.GetReplicaX().Select(&reactions, queryString, args...); err != nil { + if err := s.GetReplicaX().SelectBuilder(&reactions, query); err != nil { return nil, errors.Wrapf(err, "failed to find reaction") } if len(reactions) == 0 { @@ -270,16 +252,17 @@ func (s *SqlReactionStore) DeleteAllWithEmojiName(emojiName string) error { return nil } -func (s *SqlReactionStore) permanentDeleteReactions(userId string, postIds *[]string) error { +func (s *SqlReactionStore) permanentDeleteReactions(userId string) ([]string, error) { txn, err := s.GetMasterX().Beginx() if err != nil { - return err + return nil, err } defer finalizeTransactionX(txn, &err) - err = txn.Select(postIds, "SELECT PostId FROM Reactions WHERE UserId = ?", userId) + postIds := []string{} + err = txn.Select(&postIds, "SELECT PostId FROM Reactions WHERE UserId = ?", userId) if err != nil { - return errors.Wrapf(err, "failed to get Reactions with userId=%s", userId) + return nil, errors.Wrapf(err, "failed to get Reactions with userId=%s", userId) } query := s.getQueryBuilder(). @@ -291,19 +274,18 @@ func (s *SqlReactionStore) permanentDeleteReactions(userId string, postIds *[]st _, err = txn.ExecBuilder(query) if err != nil { - return errors.Wrapf(err, "failed to delete reactions with userId=%s", userId) + return nil, errors.Wrapf(err, "failed to delete reactions with userId=%s", userId) } if err = txn.Commit(); err != nil { - return err + return nil, err } - return nil + return postIds, nil } func (s SqlReactionStore) PermanentDeleteByUser(userId string) error { now := model.GetMillis() - postIds := []string{} - err := s.permanentDeleteReactions(userId, &postIds) + postIds, err := s.permanentDeleteReactions(userId) if err != nil { return err } diff --git a/server/channels/store/sqlstore/user_store.go b/server/channels/store/sqlstore/user_store.go index ae1d908085..f336e117f3 100644 --- a/server/channels/store/sqlstore/user_store.go +++ b/server/channels/store/sqlstore/user_store.go @@ -141,25 +141,18 @@ func (us SqlUserStore) DeactivateGuests() ([]string, error) { Where(sq.Eq{"Roles": "system_guest"}). Where(sq.Eq{"DeleteAt": 0}) - queryString, args, err := updateQuery.ToSql() - if err != nil { - return nil, errors.Wrap(err, "deactivate_guests_tosql") - } - - _, err = us.GetMasterX().Exec(queryString, args...) + _, err := us.GetMasterX().ExecBuilder(updateQuery) if err != nil { return nil, errors.Wrap(err, "failed to update Users with roles=system_guest") } - selectQuery := us.getQueryBuilder().Select("Id").From("Users").Where(sq.Eq{"DeleteAt": curTime}) - - queryString, args, err = selectQuery.ToSql() - if err != nil { - return nil, errors.Wrap(err, "deactivate_guests_tosql") - } + selectQuery := us.getQueryBuilder(). + Select("Id"). + From("Users"). + Where(sq.Eq{"DeleteAt": curTime}) userIds := []string{} - err = us.GetMasterX().Select(&userIds, queryString, args...) + err = us.GetMasterX().SelectBuilder(&userIds, selectQuery) if err != nil { return nil, errors.Wrap(err, "failed to find Users") } @@ -349,12 +342,7 @@ func (us SqlUserStore) UpdateAuthData(userId string, service string, authData *s Set("MfaUsedTimestamps", model.StringArray{}) } - queryString, args, err := updateQuery.ToSql() - if err != nil { - return "", errors.Wrap(err, "update_auth_data_tosql") - } - - if _, err := us.GetMasterX().Exec(queryString, args...); err != nil { + if _, err := us.GetMasterX().ExecBuilder(updateQuery); err != nil { if IsUniqueConstraintError(err, []string{"Email", "users_email_key", "idx_users_email_unique", "AuthData", "users_authdata_key"}) { return "", store.NewErrInvalidInput("User", "id", userId) } @@ -364,20 +352,13 @@ func (us SqlUserStore) UpdateAuthData(userId string, service string, authData *s } func (us SqlUserStore) UpdateLastLogin(userId string, lastLogin int64) error { - updateAt := model.GetMillis() - updateQuery := us.getQueryBuilder(). Update("Users"). Set("LastLogin", lastLogin). - Set("UpdateAt", updateAt). + Set("UpdateAt", model.GetMillis()). Where(sq.Eq{"Id": userId}) - queryString, args, err := updateQuery.ToSql() - if err != nil { - return errors.Wrap(err, "update_last_login_tosql") - } - - if _, err := us.GetMasterX().Exec(queryString, args...); err != nil { + if _, err := us.GetMasterX().ExecBuilder(updateQuery); err != nil { return errors.Wrapf(err, "failed to update User with userId=%s", userId) } @@ -403,23 +384,15 @@ func (us SqlUserStore) ResetAuthDataToEmailForUsers(service string, userIDs []st Select("COUNT(*)"). From("Users"). Where(whereEquals) - query, args, err := builder.ToSql() - if err != nil { - return 0, errors.Wrap(err, "select_count_users_tosql") - } var numAffected int - err = us.GetReplicaX().Get(&numAffected, query, args...) + err := us.GetReplicaX().GetBuilder(&numAffected, builder) return numAffected, err } builder := us.getQueryBuilder(). Update("Users"). Set("AuthData", sq.Expr("Email")). Where(whereEquals) - query, args, err := builder.ToSql() - if err != nil { - return 0, errors.Wrap(err, "update_users_tosql") - } - result, err := us.GetMasterX().Exec(query, args...) + result, err := us.GetMasterX().ExecBuilder(builder) if err != nil { return 0, errors.Wrap(err, "failed to update users' AuthData") } @@ -481,13 +454,8 @@ func (us SqlUserStore) GetMfaUsedTimestamps(userId string) ([]int, error) { // GetMany returns a list of users for the provided list of ids func (us SqlUserStore) GetMany(ctx context.Context, ids []string) ([]*model.User, error) { query := us.usersQuery.Where(sq.Eq{"Id": ids}) - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "users_get_many_tosql") - } - users := []*model.User{} - if err := us.SqlStore.DBXFromContext(ctx).Select(&users, queryString, args...); err != nil { + if err := us.SqlStore.DBXFromContext(ctx).SelectBuilder(&users, query); err != nil { return nil, errors.Wrap(err, "users_get_many_select") } @@ -550,13 +518,8 @@ func (us SqlUserStore) GetAllAfter(limit int, afterId string) ([]*model.User, er OrderBy("Id ASC"). Limit(uint64(limit)) - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_all_after_tosql") - } - users := []*model.User{} - if err := us.GetReplicaX().Select(&users, queryString, args...); err != nil { + if err := us.GetReplicaX().SelectBuilder(&users, query); err != nil { return nil, errors.Wrap(err, "failed to find Users") } @@ -589,13 +552,8 @@ func (us SqlUserStore) GetAllProfiles(options *model.UserGetOptions) ([]*model.U query = query.Where("u.DeleteAt = 0") } - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_all_profiles_tosql") - } - users := []*model.User{} - if err := us.GetReplicaX().Select(&users, queryString, args...); err != nil { + if err := us.GetReplicaX().SelectBuilder(&users, query); err != nil { return nil, errors.Wrap(err, "failed to get User profiles") } @@ -766,13 +724,8 @@ func (us SqlUserStore) GetProfiles(options *model.UserGetOptions) ([]*model.User query = query.Where("u.DeleteAt = 0") } - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_etag_for_profiles_tosql") - } - users := []*model.User{} - if err := us.GetReplicaX().Select(&users, queryString, args...); err != nil { + if err := us.GetReplicaX().SelectBuilder(&users, query); err != nil { return nil, errors.Wrap(err, "failed to find Users") } @@ -802,13 +755,8 @@ func (us SqlUserStore) GetProfilesInChannel(options *model.UserGetOptions) ([]*m query = applyMultiRoleFilters(query, options.Roles, options.TeamRoles, options.ChannelRoles, us.DriverName() == model.DatabaseDriverPostgres) - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_profiles_in_channel_tosql") - } - users := []*model.User{} - if err := us.GetReplicaX().Select(&users, queryString, args...); err != nil { + if err := us.GetReplicaX().SelectBuilder(&users, query); err != nil { return nil, errors.Wrap(err, "failed to find Users") } @@ -841,13 +789,8 @@ func (us SqlUserStore) GetProfilesInChannelByStatus(options *model.UserGetOption query = query.Where("u.DeleteAt = 0") } - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "get_profiles_in_channel_by_status_tosql") - } - users := []*model.User{} - if err := us.GetReplicaX().Select(&users, queryString, args...); err != nil { + if err := us.GetReplicaX().SelectBuilder(&users, query); err != nil { return nil, errors.Wrap(err, "failed to find Users") } diff --git a/server/channels/store/storetest/reaction_store.go b/server/channels/store/storetest/reaction_store.go index e146aaaeeb..64018d13af 100644 --- a/server/channels/store/storetest/reaction_store.go +++ b/server/channels/store/storetest/reaction_store.go @@ -366,6 +366,8 @@ func testReactionGetForPostSince(t *testing.T, rctx request.CTX, ss store.Store, _, err := ss.Reaction().Save(reaction) require.NoError(t, err) + time.Sleep(5 * time.Millisecond) + if del > 0 { _, err = ss.Reaction().Delete(reaction) require.NoError(t, err) @@ -383,6 +385,7 @@ func testReactionGetForPostSince(t *testing.T, rctx request.CTX, ss store.Store, returned, err := ss.Reaction().GetForPostSince(postId, later-1, "", false) require.NoError(t, err) require.Len(t, returned, 2, "should've returned 2 non-deleted reactions") + assert.Less(t, returned[0].CreateAt, returned[1].CreateAt) for _, r := range returned { assert.Zero(t, r.DeleteAt, "should not have returned deleted reaction") } @@ -394,7 +397,10 @@ func testReactionGetForPostSince(t *testing.T, rctx request.CTX, ss store.Store, require.NoError(t, err) require.Len(t, returned, 3, "should've returned 3 reactions") var count int - for _, r := range returned { + for i, r := range returned { + if i > 0 { + assert.Less(t, returned[i-1].CreateAt, returned[i].CreateAt) + } if r.DeleteAt > 0 { count++ }