mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-62142: remove SELECT * from status store (round 2) (#30060)
* Revert "Revert "[MM-62142] Avoid SELECT * in status_store.go (#29610)" (#29985)"
This reverts commit d345e92136
.
* add tests for StatusGet and StatusGetByIds
* handle NULL columns in the Status Store
* simplify status store
* more builder simplifications and tests
* expose GetQueryPlaceholder
---------
Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
parent
c4718e4542
commit
7d9521d783
@ -20,6 +20,7 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/storetest"
|
||||
sq "github.com/mattermost/squirrel"
|
||||
)
|
||||
|
||||
type StoreTestWrapper struct {
|
||||
@ -38,6 +39,10 @@ func (w *StoreTestWrapper) DriverName() string {
|
||||
return w.orig.DriverName()
|
||||
}
|
||||
|
||||
func (w *StoreTestWrapper) GetQueryPlaceholder() sq.PlaceholderFormat {
|
||||
return w.orig.getQueryPlaceholder()
|
||||
}
|
||||
|
||||
type Builder interface {
|
||||
ToSql() (string, []any, error)
|
||||
}
|
||||
|
@ -17,10 +17,28 @@ import (
|
||||
|
||||
type SqlStatusStore struct {
|
||||
*SqlStore
|
||||
|
||||
statusSelectQuery sq.SelectBuilder
|
||||
}
|
||||
|
||||
func newSqlStatusStore(sqlStore *SqlStore) store.StatusStore {
|
||||
return &SqlStatusStore{sqlStore}
|
||||
s := SqlStatusStore{
|
||||
SqlStore: sqlStore,
|
||||
}
|
||||
|
||||
manualColumnName := quoteColumnName(s.DriverName(), "Manual")
|
||||
s.statusSelectQuery = s.getQueryBuilder().
|
||||
Select(
|
||||
"COALESCE(UserId, '') AS UserId",
|
||||
"COALESCE(Status, '') AS Status",
|
||||
fmt.Sprintf("COALESCE(%s, FALSE) AS %s", manualColumnName, manualColumnName),
|
||||
"COALESCE(LastActivityAt, 0) AS LastActivityAt",
|
||||
"COALESCE(DNDEndTime, 0) AS DNDEndTime",
|
||||
"COALESCE(PrevStatus, '') AS PrevStatus",
|
||||
).
|
||||
From("Status")
|
||||
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
||||
@ -37,12 +55,7 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
||||
st.Status, st.Manual, st.LastActivityAt, st.DNDEndTime, st.PrevStatus))
|
||||
}
|
||||
|
||||
queryString, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "status_tosql")
|
||||
}
|
||||
|
||||
if _, err := s.GetMaster().Exec(queryString, args...); err != nil {
|
||||
if _, err := s.GetMaster().ExecBuilder(query); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert Status")
|
||||
}
|
||||
|
||||
@ -50,9 +63,10 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
||||
}
|
||||
|
||||
func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
|
||||
var status model.Status
|
||||
query := s.statusSelectQuery.Where(sq.Eq{"UserId": userId})
|
||||
|
||||
if err := s.GetReplica().Get(&status, "SELECT * FROM Status WHERE UserId = ?", userId); err != nil {
|
||||
var status model.Status
|
||||
if err := s.GetReplica().GetBuilder(&status, query); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, store.NewErrNotFound("Status", fmt.Sprintf("userId=%s", userId))
|
||||
}
|
||||
@ -62,30 +76,13 @@ func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
|
||||
}
|
||||
|
||||
func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
|
||||
query := s.getQueryBuilder().
|
||||
Select(fmt.Sprintf("UserId, Status, %s, LastActivityAt", quoteColumnName(s.DriverName(), "Manual"))).
|
||||
From("Status").
|
||||
Where(sq.Eq{"UserId": userIds})
|
||||
queryString, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "status_tosql")
|
||||
}
|
||||
rows, err := s.GetReplica().DB.Query(queryString, args...)
|
||||
query := s.statusSelectQuery.Where(sq.Eq{"UserId": userIds})
|
||||
|
||||
statuses := []*model.Status{}
|
||||
err := s.GetReplica().SelectBuilder(&statuses, query)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to find Statuses")
|
||||
}
|
||||
statuses := []*model.Status{}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var status model.Status
|
||||
if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt); err != nil {
|
||||
return nil, errors.Wrap(err, "unable to scan from rows")
|
||||
}
|
||||
statuses = append(statuses, &status)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "failed while iterating over rows")
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
}
|
||||
@ -94,24 +91,19 @@ func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
|
||||
func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status, error) {
|
||||
statuses := []*model.Status{}
|
||||
currUnixTime := time.Now().UTC().Unix()
|
||||
selectQuery, selectParams, err := s.getQueryBuilder().
|
||||
Select("*").
|
||||
From("Status").
|
||||
Where(
|
||||
sq.And{
|
||||
sq.Eq{"Status": model.StatusDnd},
|
||||
sq.Gt{"DNDEndTime": 0},
|
||||
sq.LtOrEq{"DNDEndTime": currUnixTime},
|
||||
},
|
||||
).ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "status_tosql")
|
||||
}
|
||||
err = t.Select(&statuses, selectQuery, selectParams...)
|
||||
selectQuery := s.statusSelectQuery.Where(
|
||||
sq.And{
|
||||
sq.Eq{"Status": model.StatusDnd},
|
||||
sq.Gt{"DNDEndTime": 0},
|
||||
sq.LtOrEq{"DNDEndTime": currUnixTime},
|
||||
},
|
||||
)
|
||||
|
||||
err := t.SelectBuilder(&statuses, selectQuery)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "updateExpiredStatusesT: failed to get expired dnd statuses")
|
||||
}
|
||||
updateQuery, args, err := s.getQueryBuilder().
|
||||
updateQuery := s.getQueryBuilder().
|
||||
Update("Status").
|
||||
Where(
|
||||
sq.And{
|
||||
@ -123,14 +115,9 @@ func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status
|
||||
Set("Status", sq.Expr("PrevStatus")).
|
||||
Set("PrevStatus", model.StatusDnd).
|
||||
Set("DNDEndTime", 0).
|
||||
Set(quoteColumnName(s.DriverName(), "Manual"), false).
|
||||
ToSql()
|
||||
Set(quoteColumnName(s.DriverName(), "Manual"), false)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "status_tosql")
|
||||
}
|
||||
|
||||
if _, err := t.Exec(updateQuery, args...); err != nil {
|
||||
if _, err := t.ExecBuilder(updateQuery); err != nil {
|
||||
return nil, errors.Wrapf(err, "updateExpiredStatusesT: failed to update statuses")
|
||||
}
|
||||
|
||||
@ -162,7 +149,7 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
queryString, args, err := s.getQueryBuilder().
|
||||
queryString := s.getQueryBuilder().
|
||||
Update("Status").
|
||||
Where(
|
||||
sq.And{
|
||||
@ -175,30 +162,13 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error
|
||||
Set("PrevStatus", model.StatusDnd).
|
||||
Set("DNDEndTime", 0).
|
||||
Set(quoteColumnName(s.DriverName(), "Manual"), false).
|
||||
Suffix("RETURNING *").
|
||||
ToSql()
|
||||
Suffix("RETURNING *")
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "status_tosql")
|
||||
}
|
||||
|
||||
rows, err := s.GetMaster().Query(queryString, args...)
|
||||
statuses := []*model.Status{}
|
||||
err = s.GetMaster().SelectBuilder(&statuses, queryString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to find Statuses")
|
||||
}
|
||||
defer rows.Close()
|
||||
statuses := []*model.Status{}
|
||||
for rows.Next() {
|
||||
var status model.Status
|
||||
if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt,
|
||||
&status.DNDEndTime, &status.PrevStatus); err != nil {
|
||||
return nil, errors.Wrap(err, "unable to scan from rows")
|
||||
}
|
||||
statuses = append(statuses, &status)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "failed while iterating over rows")
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
}
|
||||
@ -212,8 +182,13 @@ func (s SqlStatusStore) ResetAll() error {
|
||||
|
||||
func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
|
||||
time := model.GetMillis() - (1000 * 60 * 60 * 24)
|
||||
query := s.getQueryBuilder().
|
||||
Select("COUNT(UserId)").
|
||||
From("Status").
|
||||
Where(sq.Gt{"LastActivityAt": time})
|
||||
|
||||
var count int64
|
||||
err := s.GetReplica().Get(&count, "SELECT COUNT(UserId) FROM Status WHERE LastActivityAt > ?", time)
|
||||
err := s.GetReplica().GetBuilder(&count, query)
|
||||
if err != nil {
|
||||
return count, errors.Wrap(err, "failed to count active users")
|
||||
}
|
||||
@ -221,7 +196,12 @@ func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
|
||||
}
|
||||
|
||||
func (s SqlStatusStore) UpdateLastActivityAt(userId string, lastActivityAt int64) error {
|
||||
if _, err := s.GetMaster().Exec("UPDATE Status SET LastActivityAt = ? WHERE UserId = ?", lastActivityAt, userId); err != nil {
|
||||
builder := s.getQueryBuilder().
|
||||
Update("Status").
|
||||
Set("LastActivityAt", lastActivityAt).
|
||||
Where(sq.Eq{"UserId": userId})
|
||||
|
||||
if _, err := s.GetMaster().ExecBuilder(builder); err != nil {
|
||||
return errors.Wrapf(err, "failed to update last activity for userId=%s", userId)
|
||||
}
|
||||
|
||||
|
@ -10,5 +10,5 @@ import (
|
||||
)
|
||||
|
||||
func TestStatusStore(t *testing.T) {
|
||||
StoreTest(t, storetest.TestStatusStore)
|
||||
StoreTestWithSqlStore(t, storetest.TestStatusStore)
|
||||
}
|
||||
|
@ -24,11 +24,13 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/shared/timezones"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/utils"
|
||||
sq "github.com/mattermost/squirrel"
|
||||
)
|
||||
|
||||
type SqlStore interface {
|
||||
GetMaster() SqlXExecutor
|
||||
DriverName() string
|
||||
GetQueryPlaceholder() sq.PlaceholderFormat
|
||||
}
|
||||
|
||||
type SqlXExecutor interface {
|
||||
@ -5766,6 +5768,7 @@ func (s ByChannelDisplayName) Len() int { return len(s) }
|
||||
func (s ByChannelDisplayName) Swap(i, j int) {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
|
||||
func (s ByChannelDisplayName) Less(i, j int) bool {
|
||||
if s[i].DisplayName != s[j].DisplayName {
|
||||
return s[i].DisplayName < s[j].DisplayName
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sq "github.com/mattermost/squirrel"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
@ -14,13 +16,15 @@ import (
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||
)
|
||||
|
||||
func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) {
|
||||
t.Run("", func(t *testing.T) { testStatusStore(t, rctx, ss) })
|
||||
t.Run("ActiveUserCount", func(t *testing.T) { testActiveUserCount(t, rctx, ss) })
|
||||
t.Run("UpdateExpiredDNDStatuses", func(t *testing.T) { testUpdateExpiredDNDStatuses(t, rctx, ss) })
|
||||
t.Run("Get", func(t *testing.T) { testStatusGet(t, rctx, ss, s) })
|
||||
t.Run("GetByIds", func(t *testing.T) { testStatusGetByIds(t, rctx, ss, s) })
|
||||
}
|
||||
|
||||
func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
func testStatusStore(t *testing.T, _ request.CTX, ss store.Store) {
|
||||
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: 0, ActiveChannel: ""}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
||||
|
||||
@ -51,12 +55,19 @@ func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
}
|
||||
|
||||
func testActiveUserCount(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
||||
status1 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||
|
||||
count, err := ss.Status().GetTotalActiveUsersCount()
|
||||
count1, err := ss.Status().GetTotalActiveUsersCount()
|
||||
require.NoError(t, err)
|
||||
require.True(t, count > 0, "expected count > 0, got %d", count)
|
||||
assert.Greater(t, count1, int64(0))
|
||||
|
||||
status2 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status2))
|
||||
|
||||
count2, err := ss.Status().GetTotalActiveUsersCount()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, count1+1, count2)
|
||||
}
|
||||
|
||||
type ByUserId []*model.Status
|
||||
@ -68,8 +79,15 @@ func (s ByUserId) Less(i, j int) bool { return s[i].UserId < s[j].UserId }
|
||||
func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
userID := NewTestID()
|
||||
|
||||
status := &model.Status{UserId: userID, Status: model.StatusDnd, Manual: true,
|
||||
DNDEndTime: time.Now().Add(5 * time.Second).Unix(), PrevStatus: model.StatusOnline}
|
||||
status := &model.Status{
|
||||
UserId: userID,
|
||||
Status: model.StatusDnd,
|
||||
Manual: true,
|
||||
LastActivityAt: time.Now().Unix(),
|
||||
ActiveChannel: "channel-id",
|
||||
DNDEndTime: time.Now().Add(5 * time.Second).Unix(),
|
||||
PrevStatus: model.StatusOnline,
|
||||
}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
@ -87,9 +105,181 @@ func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store
|
||||
require.Len(t, statuses, 1)
|
||||
|
||||
updatedStatus := *statuses[0]
|
||||
require.Equal(t, updatedStatus.UserId, userID)
|
||||
require.Equal(t, updatedStatus.Status, model.StatusOnline)
|
||||
require.Equal(t, updatedStatus.DNDEndTime, int64(0))
|
||||
require.Equal(t, updatedStatus.PrevStatus, model.StatusDnd)
|
||||
require.Equal(t, updatedStatus.Manual, false)
|
||||
assert.Equal(t, updatedStatus.UserId, userID)
|
||||
assert.Equal(t, updatedStatus.Status, model.StatusOnline)
|
||||
assert.Equal(t, updatedStatus.Manual, false)
|
||||
assert.Equal(t, updatedStatus.LastActivityAt, updatedStatus.LastActivityAt)
|
||||
assert.Empty(t, updatedStatus.ActiveChannel)
|
||||
assert.Equal(t, updatedStatus.DNDEndTime, int64(0))
|
||||
assert.Equal(t, updatedStatus.PrevStatus, model.StatusDnd)
|
||||
}
|
||||
|
||||
func insertNullStatus(t *testing.T, ss store.Store, s SqlStore) string {
|
||||
userId := model.NewId()
|
||||
db := ss.GetInternalMasterDB()
|
||||
|
||||
// Insert status with explicit NULL values
|
||||
builder := sq.StatementBuilder.PlaceholderFormat(s.GetQueryPlaceholder()).
|
||||
Insert("Status").
|
||||
Columns("UserId", "Status", quoteColumnName(s.DriverName(), "Manual"), "LastActivityAt", "DNDEndTime", "PrevStatus").
|
||||
Values(userId, nil, nil, nil, nil, nil)
|
||||
|
||||
query, args, err := builder.ToSql()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(query, args...)
|
||||
require.NoError(t, err)
|
||||
|
||||
return userId
|
||||
}
|
||||
|
||||
func testStatusGet(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) {
|
||||
t.Run("null columns", func(t *testing.T) {
|
||||
userId := insertNullStatus(t, ss, s)
|
||||
|
||||
received, err := ss.Status().Get(userId)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userId, received.UserId)
|
||||
assert.Empty(t, received.Status)
|
||||
assert.False(t, received.Manual)
|
||||
assert.Equal(t, int64(0), received.LastActivityAt)
|
||||
assert.Empty(t, received.ActiveChannel)
|
||||
assert.Equal(t, int64(0), received.DNDEndTime)
|
||||
assert.Empty(t, received.PrevStatus)
|
||||
})
|
||||
|
||||
t.Run("status1", func(t *testing.T) {
|
||||
status1 := &model.Status{
|
||||
UserId: model.NewId(),
|
||||
Status: model.StatusDnd,
|
||||
Manual: true,
|
||||
LastActivityAt: 1234,
|
||||
ActiveChannel: "channel-id",
|
||||
DNDEndTime: model.GetMillis(),
|
||||
PrevStatus: model.StatusOnline,
|
||||
}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||
|
||||
received, err := ss.Status().Get(status1.UserId)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, status1.UserId, received.UserId)
|
||||
assert.Equal(t, status1.Status, received.Status)
|
||||
assert.Equal(t, status1.Manual, received.Manual)
|
||||
assert.Equal(t, status1.LastActivityAt, received.LastActivityAt)
|
||||
assert.Empty(t, received.ActiveChannel)
|
||||
assert.Equal(t, status1.DNDEndTime, received.DNDEndTime)
|
||||
assert.Equal(t, status1.PrevStatus, received.PrevStatus)
|
||||
})
|
||||
|
||||
t.Run("status2", func(t *testing.T) {
|
||||
status2 := &model.Status{
|
||||
UserId: model.NewId(),
|
||||
Status: model.StatusOffline,
|
||||
Manual: false,
|
||||
LastActivityAt: 12345,
|
||||
ActiveChannel: "channel-id2",
|
||||
DNDEndTime: model.GetMillis(),
|
||||
PrevStatus: model.StatusAway,
|
||||
}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status2))
|
||||
|
||||
received, err := ss.Status().Get(status2.UserId)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, status2.UserId, received.UserId)
|
||||
assert.Equal(t, status2.Status, received.Status)
|
||||
assert.Equal(t, status2.Manual, received.Manual)
|
||||
assert.Equal(t, status2.LastActivityAt, received.LastActivityAt)
|
||||
assert.Empty(t, received.ActiveChannel)
|
||||
assert.Equal(t, status2.DNDEndTime, received.DNDEndTime)
|
||||
assert.Equal(t, status2.PrevStatus, received.PrevStatus)
|
||||
})
|
||||
}
|
||||
|
||||
func testStatusGetByIds(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) {
|
||||
t.Run("null columns, single user", func(t *testing.T) {
|
||||
userId := insertNullStatus(t, ss, s)
|
||||
|
||||
received, err := ss.Status().GetByIds([]string{userId})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, received, 1)
|
||||
assert.Equal(t, userId, received[0].UserId)
|
||||
assert.Empty(t, received[0].Status)
|
||||
assert.False(t, received[0].Manual)
|
||||
assert.Equal(t, int64(0), received[0].LastActivityAt)
|
||||
assert.Empty(t, received[0].ActiveChannel)
|
||||
assert.Equal(t, int64(0), received[0].DNDEndTime)
|
||||
assert.Empty(t, received[0].PrevStatus)
|
||||
})
|
||||
|
||||
t.Run("single user", func(t *testing.T) {
|
||||
status1 := &model.Status{
|
||||
UserId: model.NewId(),
|
||||
Status: model.StatusDnd,
|
||||
Manual: true,
|
||||
LastActivityAt: 1234,
|
||||
ActiveChannel: "channel-id",
|
||||
DNDEndTime: model.GetMillis(),
|
||||
PrevStatus: model.StatusOnline,
|
||||
}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||
|
||||
received, err := ss.Status().GetByIds([]string{status1.UserId})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, received, 1)
|
||||
assert.Equal(t, status1.UserId, received[0].UserId)
|
||||
assert.Equal(t, status1.Status, received[0].Status)
|
||||
assert.Equal(t, status1.Manual, received[0].Manual)
|
||||
assert.Equal(t, status1.LastActivityAt, received[0].LastActivityAt)
|
||||
assert.Empty(t, received[0].ActiveChannel)
|
||||
assert.Equal(t, status1.DNDEndTime, received[0].DNDEndTime)
|
||||
assert.Equal(t, status1.PrevStatus, received[0].PrevStatus)
|
||||
})
|
||||
|
||||
t.Run("multiple users", func(t *testing.T) {
|
||||
status1 := &model.Status{
|
||||
UserId: model.NewId(),
|
||||
Status: model.StatusDnd,
|
||||
Manual: true,
|
||||
LastActivityAt: 1234,
|
||||
ActiveChannel: "channel-id",
|
||||
DNDEndTime: model.GetMillis(),
|
||||
PrevStatus: model.StatusOnline,
|
||||
}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||
|
||||
status2 := &model.Status{
|
||||
UserId: model.NewId(),
|
||||
Status: model.StatusOffline,
|
||||
Manual: false,
|
||||
LastActivityAt: 12345,
|
||||
ActiveChannel: "channel-id2",
|
||||
DNDEndTime: model.GetMillis(),
|
||||
PrevStatus: model.StatusAway,
|
||||
}
|
||||
require.NoError(t, ss.Status().SaveOrUpdate(status2))
|
||||
|
||||
received, err := ss.Status().GetByIds([]string{status1.UserId, status2.UserId})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, received, 2)
|
||||
|
||||
for _, status := range received {
|
||||
if status.UserId == status1.UserId {
|
||||
assert.Equal(t, status1.UserId, status.UserId)
|
||||
assert.Equal(t, status1.Status, status.Status)
|
||||
assert.Equal(t, status1.Manual, status.Manual)
|
||||
assert.Equal(t, status1.LastActivityAt, status.LastActivityAt)
|
||||
assert.Empty(t, status.ActiveChannel)
|
||||
assert.Equal(t, status1.DNDEndTime, status.DNDEndTime)
|
||||
assert.Equal(t, status1.PrevStatus, status.PrevStatus)
|
||||
} else {
|
||||
assert.Equal(t, status2.UserId, status.UserId)
|
||||
assert.Equal(t, status2.Status, status.Status)
|
||||
assert.Equal(t, status2.Manual, status.Manual)
|
||||
assert.Equal(t, status2.LastActivityAt, status.LastActivityAt)
|
||||
assert.Empty(t, status.ActiveChannel)
|
||||
assert.Equal(t, status2.DNDEndTime, status.DNDEndTime)
|
||||
assert.Equal(t, status2.PrevStatus, status.PrevStatus)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -4,6 +4,8 @@
|
||||
package storetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
)
|
||||
|
||||
@ -19,3 +21,16 @@ func NewTestID() string {
|
||||
|
||||
return string(newID)
|
||||
}
|
||||
|
||||
// Adds backtiks to the column name for MySQL, this is required if
|
||||
// the column name is a reserved keyword.
|
||||
//
|
||||
// `ColumnName` - MySQL
|
||||
// ColumnName - Postgres
|
||||
func quoteColumnName(driver string, columnName string) string {
|
||||
if driver == model.DatabaseDriverMysql {
|
||||
return fmt.Sprintf("`%s`", columnName)
|
||||
}
|
||||
|
||||
return columnName
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user