MM-62142: remove SELECT * from status store (round 2) (#30060)

* Revert "Revert "[MM-62142] Avoid SELECT * in status_store.go (#29610)" (#29985)"

This reverts commit d345e92136.

* add tests for StatusGet and StatusGetByIds

* handle NULL columns in the Status Store

* simplify status store

* more builder simplifications and tests

* expose GetQueryPlaceholder

---------

Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
Jesse Hallam 2025-02-06 12:36:31 -04:00 committed by GitHub
parent c4718e4542
commit 7d9521d783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 282 additions and 89 deletions

View File

@ -20,6 +20,7 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/store/storetest"
sq "github.com/mattermost/squirrel"
)
type StoreTestWrapper struct {
@ -38,6 +39,10 @@ func (w *StoreTestWrapper) DriverName() string {
return w.orig.DriverName()
}
func (w *StoreTestWrapper) GetQueryPlaceholder() sq.PlaceholderFormat {
return w.orig.getQueryPlaceholder()
}
type Builder interface {
ToSql() (string, []any, error)
}

View File

@ -17,10 +17,28 @@ import (
type SqlStatusStore struct {
*SqlStore
statusSelectQuery sq.SelectBuilder
}
func newSqlStatusStore(sqlStore *SqlStore) store.StatusStore {
return &SqlStatusStore{sqlStore}
s := SqlStatusStore{
SqlStore: sqlStore,
}
manualColumnName := quoteColumnName(s.DriverName(), "Manual")
s.statusSelectQuery = s.getQueryBuilder().
Select(
"COALESCE(UserId, '') AS UserId",
"COALESCE(Status, '') AS Status",
fmt.Sprintf("COALESCE(%s, FALSE) AS %s", manualColumnName, manualColumnName),
"COALESCE(LastActivityAt, 0) AS LastActivityAt",
"COALESCE(DNDEndTime, 0) AS DNDEndTime",
"COALESCE(PrevStatus, '') AS PrevStatus",
).
From("Status")
return &s
}
func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
@ -37,12 +55,7 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
st.Status, st.Manual, st.LastActivityAt, st.DNDEndTime, st.PrevStatus))
}
queryString, args, err := query.ToSql()
if err != nil {
return errors.Wrap(err, "status_tosql")
}
if _, err := s.GetMaster().Exec(queryString, args...); err != nil {
if _, err := s.GetMaster().ExecBuilder(query); err != nil {
return errors.Wrap(err, "failed to upsert Status")
}
@ -50,9 +63,10 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
}
func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
var status model.Status
query := s.statusSelectQuery.Where(sq.Eq{"UserId": userId})
if err := s.GetReplica().Get(&status, "SELECT * FROM Status WHERE UserId = ?", userId); err != nil {
var status model.Status
if err := s.GetReplica().GetBuilder(&status, query); err != nil {
if err == sql.ErrNoRows {
return nil, store.NewErrNotFound("Status", fmt.Sprintf("userId=%s", userId))
}
@ -62,30 +76,13 @@ func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
}
func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
query := s.getQueryBuilder().
Select(fmt.Sprintf("UserId, Status, %s, LastActivityAt", quoteColumnName(s.DriverName(), "Manual"))).
From("Status").
Where(sq.Eq{"UserId": userIds})
queryString, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "status_tosql")
}
rows, err := s.GetReplica().DB.Query(queryString, args...)
query := s.statusSelectQuery.Where(sq.Eq{"UserId": userIds})
statuses := []*model.Status{}
err := s.GetReplica().SelectBuilder(&statuses, query)
if err != nil {
return nil, errors.Wrap(err, "failed to find Statuses")
}
statuses := []*model.Status{}
defer rows.Close()
for rows.Next() {
var status model.Status
if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt); err != nil {
return nil, errors.Wrap(err, "unable to scan from rows")
}
statuses = append(statuses, &status)
}
if err = rows.Err(); err != nil {
return nil, errors.Wrap(err, "failed while iterating over rows")
}
return statuses, nil
}
@ -94,24 +91,19 @@ func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status, error) {
statuses := []*model.Status{}
currUnixTime := time.Now().UTC().Unix()
selectQuery, selectParams, err := s.getQueryBuilder().
Select("*").
From("Status").
Where(
selectQuery := s.statusSelectQuery.Where(
sq.And{
sq.Eq{"Status": model.StatusDnd},
sq.Gt{"DNDEndTime": 0},
sq.LtOrEq{"DNDEndTime": currUnixTime},
},
).ToSql()
if err != nil {
return nil, errors.Wrap(err, "status_tosql")
}
err = t.Select(&statuses, selectQuery, selectParams...)
)
err := t.SelectBuilder(&statuses, selectQuery)
if err != nil {
return nil, errors.Wrap(err, "updateExpiredStatusesT: failed to get expired dnd statuses")
}
updateQuery, args, err := s.getQueryBuilder().
updateQuery := s.getQueryBuilder().
Update("Status").
Where(
sq.And{
@ -123,14 +115,9 @@ func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status
Set("Status", sq.Expr("PrevStatus")).
Set("PrevStatus", model.StatusDnd).
Set("DNDEndTime", 0).
Set(quoteColumnName(s.DriverName(), "Manual"), false).
ToSql()
Set(quoteColumnName(s.DriverName(), "Manual"), false)
if err != nil {
return nil, errors.Wrap(err, "status_tosql")
}
if _, err := t.Exec(updateQuery, args...); err != nil {
if _, err := t.ExecBuilder(updateQuery); err != nil {
return nil, errors.Wrapf(err, "updateExpiredStatusesT: failed to update statuses")
}
@ -162,7 +149,7 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error
return statuses, nil
}
queryString, args, err := s.getQueryBuilder().
queryString := s.getQueryBuilder().
Update("Status").
Where(
sq.And{
@ -175,30 +162,13 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error
Set("PrevStatus", model.StatusDnd).
Set("DNDEndTime", 0).
Set(quoteColumnName(s.DriverName(), "Manual"), false).
Suffix("RETURNING *").
ToSql()
Suffix("RETURNING *")
if err != nil {
return nil, errors.Wrap(err, "status_tosql")
}
rows, err := s.GetMaster().Query(queryString, args...)
statuses := []*model.Status{}
err = s.GetMaster().SelectBuilder(&statuses, queryString)
if err != nil {
return nil, errors.Wrap(err, "failed to find Statuses")
}
defer rows.Close()
statuses := []*model.Status{}
for rows.Next() {
var status model.Status
if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt,
&status.DNDEndTime, &status.PrevStatus); err != nil {
return nil, errors.Wrap(err, "unable to scan from rows")
}
statuses = append(statuses, &status)
}
if err = rows.Err(); err != nil {
return nil, errors.Wrap(err, "failed while iterating over rows")
}
return statuses, nil
}
@ -212,8 +182,13 @@ func (s SqlStatusStore) ResetAll() error {
func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
time := model.GetMillis() - (1000 * 60 * 60 * 24)
query := s.getQueryBuilder().
Select("COUNT(UserId)").
From("Status").
Where(sq.Gt{"LastActivityAt": time})
var count int64
err := s.GetReplica().Get(&count, "SELECT COUNT(UserId) FROM Status WHERE LastActivityAt > ?", time)
err := s.GetReplica().GetBuilder(&count, query)
if err != nil {
return count, errors.Wrap(err, "failed to count active users")
}
@ -221,7 +196,12 @@ func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
}
func (s SqlStatusStore) UpdateLastActivityAt(userId string, lastActivityAt int64) error {
if _, err := s.GetMaster().Exec("UPDATE Status SET LastActivityAt = ? WHERE UserId = ?", lastActivityAt, userId); err != nil {
builder := s.getQueryBuilder().
Update("Status").
Set("LastActivityAt", lastActivityAt).
Where(sq.Eq{"UserId": userId})
if _, err := s.GetMaster().ExecBuilder(builder); err != nil {
return errors.Wrapf(err, "failed to update last activity for userId=%s", userId)
}

View File

@ -10,5 +10,5 @@ import (
)
func TestStatusStore(t *testing.T) {
StoreTest(t, storetest.TestStatusStore)
StoreTestWithSqlStore(t, storetest.TestStatusStore)
}

View File

@ -24,11 +24,13 @@ import (
"github.com/mattermost/mattermost/server/public/shared/timezones"
"github.com/mattermost/mattermost/server/v8/channels/store"
"github.com/mattermost/mattermost/server/v8/channels/utils"
sq "github.com/mattermost/squirrel"
)
type SqlStore interface {
GetMaster() SqlXExecutor
DriverName() string
GetQueryPlaceholder() sq.PlaceholderFormat
}
type SqlXExecutor interface {
@ -5766,6 +5768,7 @@ func (s ByChannelDisplayName) Len() int { return len(s) }
func (s ByChannelDisplayName) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (s ByChannelDisplayName) Less(i, j int) bool {
if s[i].DisplayName != s[j].DisplayName {
return s[i].DisplayName < s[j].DisplayName

View File

@ -7,6 +7,8 @@ import (
"testing"
"time"
sq "github.com/mattermost/squirrel"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model"
@ -14,13 +16,15 @@ import (
"github.com/mattermost/mattermost/server/v8/channels/store"
)
func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) {
t.Run("", func(t *testing.T) { testStatusStore(t, rctx, ss) })
t.Run("ActiveUserCount", func(t *testing.T) { testActiveUserCount(t, rctx, ss) })
t.Run("UpdateExpiredDNDStatuses", func(t *testing.T) { testUpdateExpiredDNDStatuses(t, rctx, ss) })
t.Run("Get", func(t *testing.T) { testStatusGet(t, rctx, ss, s) })
t.Run("GetByIds", func(t *testing.T) { testStatusGetByIds(t, rctx, ss, s) })
}
func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
func testStatusStore(t *testing.T, _ request.CTX, ss store.Store) {
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: 0, ActiveChannel: ""}
require.NoError(t, ss.Status().SaveOrUpdate(status))
@ -51,12 +55,19 @@ func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
}
func testActiveUserCount(t *testing.T, rctx request.CTX, ss store.Store) {
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
require.NoError(t, ss.Status().SaveOrUpdate(status))
status1 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
require.NoError(t, ss.Status().SaveOrUpdate(status1))
count, err := ss.Status().GetTotalActiveUsersCount()
count1, err := ss.Status().GetTotalActiveUsersCount()
require.NoError(t, err)
require.True(t, count > 0, "expected count > 0, got %d", count)
assert.Greater(t, count1, int64(0))
status2 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
require.NoError(t, ss.Status().SaveOrUpdate(status2))
count2, err := ss.Status().GetTotalActiveUsersCount()
require.NoError(t, err)
assert.Equal(t, count1+1, count2)
}
type ByUserId []*model.Status
@ -68,8 +79,15 @@ func (s ByUserId) Less(i, j int) bool { return s[i].UserId < s[j].UserId }
func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store) {
userID := NewTestID()
status := &model.Status{UserId: userID, Status: model.StatusDnd, Manual: true,
DNDEndTime: time.Now().Add(5 * time.Second).Unix(), PrevStatus: model.StatusOnline}
status := &model.Status{
UserId: userID,
Status: model.StatusDnd,
Manual: true,
LastActivityAt: time.Now().Unix(),
ActiveChannel: "channel-id",
DNDEndTime: time.Now().Add(5 * time.Second).Unix(),
PrevStatus: model.StatusOnline,
}
require.NoError(t, ss.Status().SaveOrUpdate(status))
time.Sleep(2 * time.Second)
@ -87,9 +105,181 @@ func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store
require.Len(t, statuses, 1)
updatedStatus := *statuses[0]
require.Equal(t, updatedStatus.UserId, userID)
require.Equal(t, updatedStatus.Status, model.StatusOnline)
require.Equal(t, updatedStatus.DNDEndTime, int64(0))
require.Equal(t, updatedStatus.PrevStatus, model.StatusDnd)
require.Equal(t, updatedStatus.Manual, false)
assert.Equal(t, updatedStatus.UserId, userID)
assert.Equal(t, updatedStatus.Status, model.StatusOnline)
assert.Equal(t, updatedStatus.Manual, false)
assert.Equal(t, updatedStatus.LastActivityAt, updatedStatus.LastActivityAt)
assert.Empty(t, updatedStatus.ActiveChannel)
assert.Equal(t, updatedStatus.DNDEndTime, int64(0))
assert.Equal(t, updatedStatus.PrevStatus, model.StatusDnd)
}
func insertNullStatus(t *testing.T, ss store.Store, s SqlStore) string {
userId := model.NewId()
db := ss.GetInternalMasterDB()
// Insert status with explicit NULL values
builder := sq.StatementBuilder.PlaceholderFormat(s.GetQueryPlaceholder()).
Insert("Status").
Columns("UserId", "Status", quoteColumnName(s.DriverName(), "Manual"), "LastActivityAt", "DNDEndTime", "PrevStatus").
Values(userId, nil, nil, nil, nil, nil)
query, args, err := builder.ToSql()
require.NoError(t, err)
_, err = db.Exec(query, args...)
require.NoError(t, err)
return userId
}
func testStatusGet(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) {
t.Run("null columns", func(t *testing.T) {
userId := insertNullStatus(t, ss, s)
received, err := ss.Status().Get(userId)
require.NoError(t, err)
assert.Equal(t, userId, received.UserId)
assert.Empty(t, received.Status)
assert.False(t, received.Manual)
assert.Equal(t, int64(0), received.LastActivityAt)
assert.Empty(t, received.ActiveChannel)
assert.Equal(t, int64(0), received.DNDEndTime)
assert.Empty(t, received.PrevStatus)
})
t.Run("status1", func(t *testing.T) {
status1 := &model.Status{
UserId: model.NewId(),
Status: model.StatusDnd,
Manual: true,
LastActivityAt: 1234,
ActiveChannel: "channel-id",
DNDEndTime: model.GetMillis(),
PrevStatus: model.StatusOnline,
}
require.NoError(t, ss.Status().SaveOrUpdate(status1))
received, err := ss.Status().Get(status1.UserId)
require.NoError(t, err)
assert.Equal(t, status1.UserId, received.UserId)
assert.Equal(t, status1.Status, received.Status)
assert.Equal(t, status1.Manual, received.Manual)
assert.Equal(t, status1.LastActivityAt, received.LastActivityAt)
assert.Empty(t, received.ActiveChannel)
assert.Equal(t, status1.DNDEndTime, received.DNDEndTime)
assert.Equal(t, status1.PrevStatus, received.PrevStatus)
})
t.Run("status2", func(t *testing.T) {
status2 := &model.Status{
UserId: model.NewId(),
Status: model.StatusOffline,
Manual: false,
LastActivityAt: 12345,
ActiveChannel: "channel-id2",
DNDEndTime: model.GetMillis(),
PrevStatus: model.StatusAway,
}
require.NoError(t, ss.Status().SaveOrUpdate(status2))
received, err := ss.Status().Get(status2.UserId)
require.NoError(t, err)
assert.Equal(t, status2.UserId, received.UserId)
assert.Equal(t, status2.Status, received.Status)
assert.Equal(t, status2.Manual, received.Manual)
assert.Equal(t, status2.LastActivityAt, received.LastActivityAt)
assert.Empty(t, received.ActiveChannel)
assert.Equal(t, status2.DNDEndTime, received.DNDEndTime)
assert.Equal(t, status2.PrevStatus, received.PrevStatus)
})
}
func testStatusGetByIds(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) {
t.Run("null columns, single user", func(t *testing.T) {
userId := insertNullStatus(t, ss, s)
received, err := ss.Status().GetByIds([]string{userId})
require.NoError(t, err)
require.Len(t, received, 1)
assert.Equal(t, userId, received[0].UserId)
assert.Empty(t, received[0].Status)
assert.False(t, received[0].Manual)
assert.Equal(t, int64(0), received[0].LastActivityAt)
assert.Empty(t, received[0].ActiveChannel)
assert.Equal(t, int64(0), received[0].DNDEndTime)
assert.Empty(t, received[0].PrevStatus)
})
t.Run("single user", func(t *testing.T) {
status1 := &model.Status{
UserId: model.NewId(),
Status: model.StatusDnd,
Manual: true,
LastActivityAt: 1234,
ActiveChannel: "channel-id",
DNDEndTime: model.GetMillis(),
PrevStatus: model.StatusOnline,
}
require.NoError(t, ss.Status().SaveOrUpdate(status1))
received, err := ss.Status().GetByIds([]string{status1.UserId})
require.NoError(t, err)
require.Len(t, received, 1)
assert.Equal(t, status1.UserId, received[0].UserId)
assert.Equal(t, status1.Status, received[0].Status)
assert.Equal(t, status1.Manual, received[0].Manual)
assert.Equal(t, status1.LastActivityAt, received[0].LastActivityAt)
assert.Empty(t, received[0].ActiveChannel)
assert.Equal(t, status1.DNDEndTime, received[0].DNDEndTime)
assert.Equal(t, status1.PrevStatus, received[0].PrevStatus)
})
t.Run("multiple users", func(t *testing.T) {
status1 := &model.Status{
UserId: model.NewId(),
Status: model.StatusDnd,
Manual: true,
LastActivityAt: 1234,
ActiveChannel: "channel-id",
DNDEndTime: model.GetMillis(),
PrevStatus: model.StatusOnline,
}
require.NoError(t, ss.Status().SaveOrUpdate(status1))
status2 := &model.Status{
UserId: model.NewId(),
Status: model.StatusOffline,
Manual: false,
LastActivityAt: 12345,
ActiveChannel: "channel-id2",
DNDEndTime: model.GetMillis(),
PrevStatus: model.StatusAway,
}
require.NoError(t, ss.Status().SaveOrUpdate(status2))
received, err := ss.Status().GetByIds([]string{status1.UserId, status2.UserId})
require.NoError(t, err)
require.Len(t, received, 2)
for _, status := range received {
if status.UserId == status1.UserId {
assert.Equal(t, status1.UserId, status.UserId)
assert.Equal(t, status1.Status, status.Status)
assert.Equal(t, status1.Manual, status.Manual)
assert.Equal(t, status1.LastActivityAt, status.LastActivityAt)
assert.Empty(t, status.ActiveChannel)
assert.Equal(t, status1.DNDEndTime, status.DNDEndTime)
assert.Equal(t, status1.PrevStatus, status.PrevStatus)
} else {
assert.Equal(t, status2.UserId, status.UserId)
assert.Equal(t, status2.Status, status.Status)
assert.Equal(t, status2.Manual, status.Manual)
assert.Equal(t, status2.LastActivityAt, status.LastActivityAt)
assert.Empty(t, status.ActiveChannel)
assert.Equal(t, status2.DNDEndTime, status.DNDEndTime)
assert.Equal(t, status2.PrevStatus, status.PrevStatus)
}
}
})
}

View File

@ -4,6 +4,8 @@
package storetest
import (
"fmt"
"github.com/mattermost/mattermost/server/public/model"
)
@ -19,3 +21,16 @@ func NewTestID() string {
return string(newID)
}
// Adds backtiks to the column name for MySQL, this is required if
// the column name is a reserved keyword.
//
// `ColumnName` - MySQL
// ColumnName - Postgres
func quoteColumnName(driver string, columnName string) string {
if driver == model.DatabaseDriverMysql {
return fmt.Sprintf("`%s`", columnName)
}
return columnName
}