diff --git a/server/channels/store/sqlstore/sqlx_wrapper.go b/server/channels/store/sqlstore/sqlx_wrapper.go index bd294ba628..7a9a461085 100644 --- a/server/channels/store/sqlstore/sqlx_wrapper.go +++ b/server/channels/store/sqlstore/sqlx_wrapper.go @@ -20,6 +20,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/channels/store/storetest" + sq "github.com/mattermost/squirrel" ) type StoreTestWrapper struct { @@ -38,6 +39,10 @@ func (w *StoreTestWrapper) DriverName() string { return w.orig.DriverName() } +func (w *StoreTestWrapper) GetQueryPlaceholder() sq.PlaceholderFormat { + return w.orig.getQueryPlaceholder() +} + type Builder interface { ToSql() (string, []any, error) } diff --git a/server/channels/store/sqlstore/status_store.go b/server/channels/store/sqlstore/status_store.go index 40f55773cc..1922d3a621 100644 --- a/server/channels/store/sqlstore/status_store.go +++ b/server/channels/store/sqlstore/status_store.go @@ -17,10 +17,28 @@ import ( type SqlStatusStore struct { *SqlStore + + statusSelectQuery sq.SelectBuilder } func newSqlStatusStore(sqlStore *SqlStore) store.StatusStore { - return &SqlStatusStore{sqlStore} + s := SqlStatusStore{ + SqlStore: sqlStore, + } + + manualColumnName := quoteColumnName(s.DriverName(), "Manual") + s.statusSelectQuery = s.getQueryBuilder(). + Select( + "COALESCE(UserId, '') AS UserId", + "COALESCE(Status, '') AS Status", + fmt.Sprintf("COALESCE(%s, FALSE) AS %s", manualColumnName, manualColumnName), + "COALESCE(LastActivityAt, 0) AS LastActivityAt", + "COALESCE(DNDEndTime, 0) AS DNDEndTime", + "COALESCE(PrevStatus, '') AS PrevStatus", + ). + From("Status") + + return &s } func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error { @@ -37,12 +55,7 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error { st.Status, st.Manual, st.LastActivityAt, st.DNDEndTime, st.PrevStatus)) } - queryString, args, err := query.ToSql() - if err != nil { - return errors.Wrap(err, "status_tosql") - } - - if _, err := s.GetMaster().Exec(queryString, args...); err != nil { + if _, err := s.GetMaster().ExecBuilder(query); err != nil { return errors.Wrap(err, "failed to upsert Status") } @@ -50,9 +63,10 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error { } func (s SqlStatusStore) Get(userId string) (*model.Status, error) { - var status model.Status + query := s.statusSelectQuery.Where(sq.Eq{"UserId": userId}) - if err := s.GetReplica().Get(&status, "SELECT * FROM Status WHERE UserId = ?", userId); err != nil { + var status model.Status + if err := s.GetReplica().GetBuilder(&status, query); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Status", fmt.Sprintf("userId=%s", userId)) } @@ -62,30 +76,13 @@ func (s SqlStatusStore) Get(userId string) (*model.Status, error) { } func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) { - query := s.getQueryBuilder(). - Select(fmt.Sprintf("UserId, Status, %s, LastActivityAt", quoteColumnName(s.DriverName(), "Manual"))). - From("Status"). - Where(sq.Eq{"UserId": userIds}) - queryString, args, err := query.ToSql() - if err != nil { - return nil, errors.Wrap(err, "status_tosql") - } - rows, err := s.GetReplica().DB.Query(queryString, args...) + query := s.statusSelectQuery.Where(sq.Eq{"UserId": userIds}) + + statuses := []*model.Status{} + err := s.GetReplica().SelectBuilder(&statuses, query) if err != nil { return nil, errors.Wrap(err, "failed to find Statuses") } - statuses := []*model.Status{} - defer rows.Close() - for rows.Next() { - var status model.Status - if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt); err != nil { - return nil, errors.Wrap(err, "unable to scan from rows") - } - statuses = append(statuses, &status) - } - if err = rows.Err(); err != nil { - return nil, errors.Wrap(err, "failed while iterating over rows") - } return statuses, nil } @@ -94,24 +91,19 @@ func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) { func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status, error) { statuses := []*model.Status{} currUnixTime := time.Now().UTC().Unix() - selectQuery, selectParams, err := s.getQueryBuilder(). - Select("*"). - From("Status"). - Where( - sq.And{ - sq.Eq{"Status": model.StatusDnd}, - sq.Gt{"DNDEndTime": 0}, - sq.LtOrEq{"DNDEndTime": currUnixTime}, - }, - ).ToSql() - if err != nil { - return nil, errors.Wrap(err, "status_tosql") - } - err = t.Select(&statuses, selectQuery, selectParams...) + selectQuery := s.statusSelectQuery.Where( + sq.And{ + sq.Eq{"Status": model.StatusDnd}, + sq.Gt{"DNDEndTime": 0}, + sq.LtOrEq{"DNDEndTime": currUnixTime}, + }, + ) + + err := t.SelectBuilder(&statuses, selectQuery) if err != nil { return nil, errors.Wrap(err, "updateExpiredStatusesT: failed to get expired dnd statuses") } - updateQuery, args, err := s.getQueryBuilder(). + updateQuery := s.getQueryBuilder(). Update("Status"). Where( sq.And{ @@ -123,14 +115,9 @@ func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status Set("Status", sq.Expr("PrevStatus")). Set("PrevStatus", model.StatusDnd). Set("DNDEndTime", 0). - Set(quoteColumnName(s.DriverName(), "Manual"), false). - ToSql() + Set(quoteColumnName(s.DriverName(), "Manual"), false) - if err != nil { - return nil, errors.Wrap(err, "status_tosql") - } - - if _, err := t.Exec(updateQuery, args...); err != nil { + if _, err := t.ExecBuilder(updateQuery); err != nil { return nil, errors.Wrapf(err, "updateExpiredStatusesT: failed to update statuses") } @@ -162,7 +149,7 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error return statuses, nil } - queryString, args, err := s.getQueryBuilder(). + queryString := s.getQueryBuilder(). Update("Status"). Where( sq.And{ @@ -175,30 +162,13 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error Set("PrevStatus", model.StatusDnd). Set("DNDEndTime", 0). Set(quoteColumnName(s.DriverName(), "Manual"), false). - Suffix("RETURNING *"). - ToSql() + Suffix("RETURNING *") - if err != nil { - return nil, errors.Wrap(err, "status_tosql") - } - - rows, err := s.GetMaster().Query(queryString, args...) + statuses := []*model.Status{} + err = s.GetMaster().SelectBuilder(&statuses, queryString) if err != nil { return nil, errors.Wrap(err, "failed to find Statuses") } - defer rows.Close() - statuses := []*model.Status{} - for rows.Next() { - var status model.Status - if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt, - &status.DNDEndTime, &status.PrevStatus); err != nil { - return nil, errors.Wrap(err, "unable to scan from rows") - } - statuses = append(statuses, &status) - } - if err = rows.Err(); err != nil { - return nil, errors.Wrap(err, "failed while iterating over rows") - } return statuses, nil } @@ -212,8 +182,13 @@ func (s SqlStatusStore) ResetAll() error { func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) { time := model.GetMillis() - (1000 * 60 * 60 * 24) + query := s.getQueryBuilder(). + Select("COUNT(UserId)"). + From("Status"). + Where(sq.Gt{"LastActivityAt": time}) + var count int64 - err := s.GetReplica().Get(&count, "SELECT COUNT(UserId) FROM Status WHERE LastActivityAt > ?", time) + err := s.GetReplica().GetBuilder(&count, query) if err != nil { return count, errors.Wrap(err, "failed to count active users") } @@ -221,7 +196,12 @@ func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) { } func (s SqlStatusStore) UpdateLastActivityAt(userId string, lastActivityAt int64) error { - if _, err := s.GetMaster().Exec("UPDATE Status SET LastActivityAt = ? WHERE UserId = ?", lastActivityAt, userId); err != nil { + builder := s.getQueryBuilder(). + Update("Status"). + Set("LastActivityAt", lastActivityAt). + Where(sq.Eq{"UserId": userId}) + + if _, err := s.GetMaster().ExecBuilder(builder); err != nil { return errors.Wrapf(err, "failed to update last activity for userId=%s", userId) } diff --git a/server/channels/store/sqlstore/status_store_test.go b/server/channels/store/sqlstore/status_store_test.go index a6f1d9a454..1f0b049f50 100644 --- a/server/channels/store/sqlstore/status_store_test.go +++ b/server/channels/store/sqlstore/status_store_test.go @@ -10,5 +10,5 @@ import ( ) func TestStatusStore(t *testing.T) { - StoreTest(t, storetest.TestStatusStore) + StoreTestWithSqlStore(t, storetest.TestStatusStore) } diff --git a/server/channels/store/storetest/channel_store.go b/server/channels/store/storetest/channel_store.go index abd9651c3d..521f61c26c 100644 --- a/server/channels/store/storetest/channel_store.go +++ b/server/channels/store/storetest/channel_store.go @@ -24,11 +24,13 @@ import ( "github.com/mattermost/mattermost/server/public/shared/timezones" "github.com/mattermost/mattermost/server/v8/channels/store" "github.com/mattermost/mattermost/server/v8/channels/utils" + sq "github.com/mattermost/squirrel" ) type SqlStore interface { GetMaster() SqlXExecutor DriverName() string + GetQueryPlaceholder() sq.PlaceholderFormat } type SqlXExecutor interface { @@ -5766,6 +5768,7 @@ func (s ByChannelDisplayName) Len() int { return len(s) } func (s ByChannelDisplayName) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + func (s ByChannelDisplayName) Less(i, j int) bool { if s[i].DisplayName != s[j].DisplayName { return s[i].DisplayName < s[j].DisplayName diff --git a/server/channels/store/storetest/status_store.go b/server/channels/store/storetest/status_store.go index afc456571d..9215dc7fb3 100644 --- a/server/channels/store/storetest/status_store.go +++ b/server/channels/store/storetest/status_store.go @@ -7,6 +7,8 @@ import ( "testing" "time" + sq "github.com/mattermost/squirrel" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" @@ -14,13 +16,15 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/store" ) -func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store) { +func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) { t.Run("", func(t *testing.T) { testStatusStore(t, rctx, ss) }) t.Run("ActiveUserCount", func(t *testing.T) { testActiveUserCount(t, rctx, ss) }) t.Run("UpdateExpiredDNDStatuses", func(t *testing.T) { testUpdateExpiredDNDStatuses(t, rctx, ss) }) + t.Run("Get", func(t *testing.T) { testStatusGet(t, rctx, ss, s) }) + t.Run("GetByIds", func(t *testing.T) { testStatusGetByIds(t, rctx, ss, s) }) } -func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) { +func testStatusStore(t *testing.T, _ request.CTX, ss store.Store) { status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: 0, ActiveChannel: ""} require.NoError(t, ss.Status().SaveOrUpdate(status)) @@ -51,12 +55,19 @@ func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) { } func testActiveUserCount(t *testing.T, rctx request.CTX, ss store.Store) { - status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""} - require.NoError(t, ss.Status().SaveOrUpdate(status)) + status1 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""} + require.NoError(t, ss.Status().SaveOrUpdate(status1)) - count, err := ss.Status().GetTotalActiveUsersCount() + count1, err := ss.Status().GetTotalActiveUsersCount() require.NoError(t, err) - require.True(t, count > 0, "expected count > 0, got %d", count) + assert.Greater(t, count1, int64(0)) + + status2 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""} + require.NoError(t, ss.Status().SaveOrUpdate(status2)) + + count2, err := ss.Status().GetTotalActiveUsersCount() + require.NoError(t, err) + assert.Equal(t, count1+1, count2) } type ByUserId []*model.Status @@ -68,8 +79,15 @@ func (s ByUserId) Less(i, j int) bool { return s[i].UserId < s[j].UserId } func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store) { userID := NewTestID() - status := &model.Status{UserId: userID, Status: model.StatusDnd, Manual: true, - DNDEndTime: time.Now().Add(5 * time.Second).Unix(), PrevStatus: model.StatusOnline} + status := &model.Status{ + UserId: userID, + Status: model.StatusDnd, + Manual: true, + LastActivityAt: time.Now().Unix(), + ActiveChannel: "channel-id", + DNDEndTime: time.Now().Add(5 * time.Second).Unix(), + PrevStatus: model.StatusOnline, + } require.NoError(t, ss.Status().SaveOrUpdate(status)) time.Sleep(2 * time.Second) @@ -87,9 +105,181 @@ func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store require.Len(t, statuses, 1) updatedStatus := *statuses[0] - require.Equal(t, updatedStatus.UserId, userID) - require.Equal(t, updatedStatus.Status, model.StatusOnline) - require.Equal(t, updatedStatus.DNDEndTime, int64(0)) - require.Equal(t, updatedStatus.PrevStatus, model.StatusDnd) - require.Equal(t, updatedStatus.Manual, false) + assert.Equal(t, updatedStatus.UserId, userID) + assert.Equal(t, updatedStatus.Status, model.StatusOnline) + assert.Equal(t, updatedStatus.Manual, false) + assert.Equal(t, updatedStatus.LastActivityAt, updatedStatus.LastActivityAt) + assert.Empty(t, updatedStatus.ActiveChannel) + assert.Equal(t, updatedStatus.DNDEndTime, int64(0)) + assert.Equal(t, updatedStatus.PrevStatus, model.StatusDnd) +} + +func insertNullStatus(t *testing.T, ss store.Store, s SqlStore) string { + userId := model.NewId() + db := ss.GetInternalMasterDB() + + // Insert status with explicit NULL values + builder := sq.StatementBuilder.PlaceholderFormat(s.GetQueryPlaceholder()). + Insert("Status"). + Columns("UserId", "Status", quoteColumnName(s.DriverName(), "Manual"), "LastActivityAt", "DNDEndTime", "PrevStatus"). + Values(userId, nil, nil, nil, nil, nil) + + query, args, err := builder.ToSql() + require.NoError(t, err) + + _, err = db.Exec(query, args...) + require.NoError(t, err) + + return userId +} + +func testStatusGet(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) { + t.Run("null columns", func(t *testing.T) { + userId := insertNullStatus(t, ss, s) + + received, err := ss.Status().Get(userId) + require.NoError(t, err) + assert.Equal(t, userId, received.UserId) + assert.Empty(t, received.Status) + assert.False(t, received.Manual) + assert.Equal(t, int64(0), received.LastActivityAt) + assert.Empty(t, received.ActiveChannel) + assert.Equal(t, int64(0), received.DNDEndTime) + assert.Empty(t, received.PrevStatus) + }) + + t.Run("status1", func(t *testing.T) { + status1 := &model.Status{ + UserId: model.NewId(), + Status: model.StatusDnd, + Manual: true, + LastActivityAt: 1234, + ActiveChannel: "channel-id", + DNDEndTime: model.GetMillis(), + PrevStatus: model.StatusOnline, + } + require.NoError(t, ss.Status().SaveOrUpdate(status1)) + + received, err := ss.Status().Get(status1.UserId) + require.NoError(t, err) + assert.Equal(t, status1.UserId, received.UserId) + assert.Equal(t, status1.Status, received.Status) + assert.Equal(t, status1.Manual, received.Manual) + assert.Equal(t, status1.LastActivityAt, received.LastActivityAt) + assert.Empty(t, received.ActiveChannel) + assert.Equal(t, status1.DNDEndTime, received.DNDEndTime) + assert.Equal(t, status1.PrevStatus, received.PrevStatus) + }) + + t.Run("status2", func(t *testing.T) { + status2 := &model.Status{ + UserId: model.NewId(), + Status: model.StatusOffline, + Manual: false, + LastActivityAt: 12345, + ActiveChannel: "channel-id2", + DNDEndTime: model.GetMillis(), + PrevStatus: model.StatusAway, + } + require.NoError(t, ss.Status().SaveOrUpdate(status2)) + + received, err := ss.Status().Get(status2.UserId) + require.NoError(t, err) + assert.Equal(t, status2.UserId, received.UserId) + assert.Equal(t, status2.Status, received.Status) + assert.Equal(t, status2.Manual, received.Manual) + assert.Equal(t, status2.LastActivityAt, received.LastActivityAt) + assert.Empty(t, received.ActiveChannel) + assert.Equal(t, status2.DNDEndTime, received.DNDEndTime) + assert.Equal(t, status2.PrevStatus, received.PrevStatus) + }) +} + +func testStatusGetByIds(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) { + t.Run("null columns, single user", func(t *testing.T) { + userId := insertNullStatus(t, ss, s) + + received, err := ss.Status().GetByIds([]string{userId}) + require.NoError(t, err) + require.Len(t, received, 1) + assert.Equal(t, userId, received[0].UserId) + assert.Empty(t, received[0].Status) + assert.False(t, received[0].Manual) + assert.Equal(t, int64(0), received[0].LastActivityAt) + assert.Empty(t, received[0].ActiveChannel) + assert.Equal(t, int64(0), received[0].DNDEndTime) + assert.Empty(t, received[0].PrevStatus) + }) + + t.Run("single user", func(t *testing.T) { + status1 := &model.Status{ + UserId: model.NewId(), + Status: model.StatusDnd, + Manual: true, + LastActivityAt: 1234, + ActiveChannel: "channel-id", + DNDEndTime: model.GetMillis(), + PrevStatus: model.StatusOnline, + } + require.NoError(t, ss.Status().SaveOrUpdate(status1)) + + received, err := ss.Status().GetByIds([]string{status1.UserId}) + require.NoError(t, err) + require.Len(t, received, 1) + assert.Equal(t, status1.UserId, received[0].UserId) + assert.Equal(t, status1.Status, received[0].Status) + assert.Equal(t, status1.Manual, received[0].Manual) + assert.Equal(t, status1.LastActivityAt, received[0].LastActivityAt) + assert.Empty(t, received[0].ActiveChannel) + assert.Equal(t, status1.DNDEndTime, received[0].DNDEndTime) + assert.Equal(t, status1.PrevStatus, received[0].PrevStatus) + }) + + t.Run("multiple users", func(t *testing.T) { + status1 := &model.Status{ + UserId: model.NewId(), + Status: model.StatusDnd, + Manual: true, + LastActivityAt: 1234, + ActiveChannel: "channel-id", + DNDEndTime: model.GetMillis(), + PrevStatus: model.StatusOnline, + } + require.NoError(t, ss.Status().SaveOrUpdate(status1)) + + status2 := &model.Status{ + UserId: model.NewId(), + Status: model.StatusOffline, + Manual: false, + LastActivityAt: 12345, + ActiveChannel: "channel-id2", + DNDEndTime: model.GetMillis(), + PrevStatus: model.StatusAway, + } + require.NoError(t, ss.Status().SaveOrUpdate(status2)) + + received, err := ss.Status().GetByIds([]string{status1.UserId, status2.UserId}) + require.NoError(t, err) + require.Len(t, received, 2) + + for _, status := range received { + if status.UserId == status1.UserId { + assert.Equal(t, status1.UserId, status.UserId) + assert.Equal(t, status1.Status, status.Status) + assert.Equal(t, status1.Manual, status.Manual) + assert.Equal(t, status1.LastActivityAt, status.LastActivityAt) + assert.Empty(t, status.ActiveChannel) + assert.Equal(t, status1.DNDEndTime, status.DNDEndTime) + assert.Equal(t, status1.PrevStatus, status.PrevStatus) + } else { + assert.Equal(t, status2.UserId, status.UserId) + assert.Equal(t, status2.Status, status.Status) + assert.Equal(t, status2.Manual, status.Manual) + assert.Equal(t, status2.LastActivityAt, status.LastActivityAt) + assert.Empty(t, status.ActiveChannel) + assert.Equal(t, status2.DNDEndTime, status.DNDEndTime) + assert.Equal(t, status2.PrevStatus, status.PrevStatus) + } + } + }) } diff --git a/server/channels/store/storetest/utils.go b/server/channels/store/storetest/utils.go index 5a9995aa3c..5f437415b7 100644 --- a/server/channels/store/storetest/utils.go +++ b/server/channels/store/storetest/utils.go @@ -4,6 +4,8 @@ package storetest import ( + "fmt" + "github.com/mattermost/mattermost/server/public/model" ) @@ -19,3 +21,16 @@ func NewTestID() string { return string(newID) } + +// Adds backtiks to the column name for MySQL, this is required if +// the column name is a reserved keyword. +// +// `ColumnName` - MySQL +// ColumnName - Postgres +func quoteColumnName(driver string, columnName string) string { + if driver == model.DatabaseDriverMysql { + return fmt.Sprintf("`%s`", columnName) + } + + return columnName +}