mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-62142: remove SELECT * from status store (round 2) (#30060)
* Revert "Revert "[MM-62142] Avoid SELECT * in status_store.go (#29610)" (#29985)"
This reverts commit d345e92136
.
* add tests for StatusGet and StatusGetByIds
* handle NULL columns in the Status Store
* simplify status store
* more builder simplifications and tests
* expose GetQueryPlaceholder
---------
Co-authored-by: Mattermost Build <build@mattermost.com>
This commit is contained in:
parent
c4718e4542
commit
7d9521d783
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/mattermost/mattermost/server/public/model"
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||||
"github.com/mattermost/mattermost/server/v8/channels/store/storetest"
|
"github.com/mattermost/mattermost/server/v8/channels/store/storetest"
|
||||||
|
sq "github.com/mattermost/squirrel"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StoreTestWrapper struct {
|
type StoreTestWrapper struct {
|
||||||
@ -38,6 +39,10 @@ func (w *StoreTestWrapper) DriverName() string {
|
|||||||
return w.orig.DriverName()
|
return w.orig.DriverName()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *StoreTestWrapper) GetQueryPlaceholder() sq.PlaceholderFormat {
|
||||||
|
return w.orig.getQueryPlaceholder()
|
||||||
|
}
|
||||||
|
|
||||||
type Builder interface {
|
type Builder interface {
|
||||||
ToSql() (string, []any, error)
|
ToSql() (string, []any, error)
|
||||||
}
|
}
|
||||||
|
@ -17,10 +17,28 @@ import (
|
|||||||
|
|
||||||
type SqlStatusStore struct {
|
type SqlStatusStore struct {
|
||||||
*SqlStore
|
*SqlStore
|
||||||
|
|
||||||
|
statusSelectQuery sq.SelectBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSqlStatusStore(sqlStore *SqlStore) store.StatusStore {
|
func newSqlStatusStore(sqlStore *SqlStore) store.StatusStore {
|
||||||
return &SqlStatusStore{sqlStore}
|
s := SqlStatusStore{
|
||||||
|
SqlStore: sqlStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
manualColumnName := quoteColumnName(s.DriverName(), "Manual")
|
||||||
|
s.statusSelectQuery = s.getQueryBuilder().
|
||||||
|
Select(
|
||||||
|
"COALESCE(UserId, '') AS UserId",
|
||||||
|
"COALESCE(Status, '') AS Status",
|
||||||
|
fmt.Sprintf("COALESCE(%s, FALSE) AS %s", manualColumnName, manualColumnName),
|
||||||
|
"COALESCE(LastActivityAt, 0) AS LastActivityAt",
|
||||||
|
"COALESCE(DNDEndTime, 0) AS DNDEndTime",
|
||||||
|
"COALESCE(PrevStatus, '') AS PrevStatus",
|
||||||
|
).
|
||||||
|
From("Status")
|
||||||
|
|
||||||
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
||||||
@ -37,12 +55,7 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
|||||||
st.Status, st.Manual, st.LastActivityAt, st.DNDEndTime, st.PrevStatus))
|
st.Status, st.Manual, st.LastActivityAt, st.DNDEndTime, st.PrevStatus))
|
||||||
}
|
}
|
||||||
|
|
||||||
queryString, args, err := query.ToSql()
|
if _, err := s.GetMaster().ExecBuilder(query); err != nil {
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "status_tosql")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := s.GetMaster().Exec(queryString, args...); err != nil {
|
|
||||||
return errors.Wrap(err, "failed to upsert Status")
|
return errors.Wrap(err, "failed to upsert Status")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,9 +63,10 @@ func (s SqlStatusStore) SaveOrUpdate(st *model.Status) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
|
func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
|
||||||
var status model.Status
|
query := s.statusSelectQuery.Where(sq.Eq{"UserId": userId})
|
||||||
|
|
||||||
if err := s.GetReplica().Get(&status, "SELECT * FROM Status WHERE UserId = ?", userId); err != nil {
|
var status model.Status
|
||||||
|
if err := s.GetReplica().GetBuilder(&status, query); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, store.NewErrNotFound("Status", fmt.Sprintf("userId=%s", userId))
|
return nil, store.NewErrNotFound("Status", fmt.Sprintf("userId=%s", userId))
|
||||||
}
|
}
|
||||||
@ -62,30 +76,13 @@ func (s SqlStatusStore) Get(userId string) (*model.Status, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
|
func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
|
||||||
query := s.getQueryBuilder().
|
query := s.statusSelectQuery.Where(sq.Eq{"UserId": userIds})
|
||||||
Select(fmt.Sprintf("UserId, Status, %s, LastActivityAt", quoteColumnName(s.DriverName(), "Manual"))).
|
|
||||||
From("Status").
|
statuses := []*model.Status{}
|
||||||
Where(sq.Eq{"UserId": userIds})
|
err := s.GetReplica().SelectBuilder(&statuses, query)
|
||||||
queryString, args, err := query.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "status_tosql")
|
|
||||||
}
|
|
||||||
rows, err := s.GetReplica().DB.Query(queryString, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to find Statuses")
|
return nil, errors.Wrap(err, "failed to find Statuses")
|
||||||
}
|
}
|
||||||
statuses := []*model.Status{}
|
|
||||||
defer rows.Close()
|
|
||||||
for rows.Next() {
|
|
||||||
var status model.Status
|
|
||||||
if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "unable to scan from rows")
|
|
||||||
}
|
|
||||||
statuses = append(statuses, &status)
|
|
||||||
}
|
|
||||||
if err = rows.Err(); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed while iterating over rows")
|
|
||||||
}
|
|
||||||
|
|
||||||
return statuses, nil
|
return statuses, nil
|
||||||
}
|
}
|
||||||
@ -94,24 +91,19 @@ func (s SqlStatusStore) GetByIds(userIds []string) ([]*model.Status, error) {
|
|||||||
func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status, error) {
|
func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status, error) {
|
||||||
statuses := []*model.Status{}
|
statuses := []*model.Status{}
|
||||||
currUnixTime := time.Now().UTC().Unix()
|
currUnixTime := time.Now().UTC().Unix()
|
||||||
selectQuery, selectParams, err := s.getQueryBuilder().
|
selectQuery := s.statusSelectQuery.Where(
|
||||||
Select("*").
|
sq.And{
|
||||||
From("Status").
|
sq.Eq{"Status": model.StatusDnd},
|
||||||
Where(
|
sq.Gt{"DNDEndTime": 0},
|
||||||
sq.And{
|
sq.LtOrEq{"DNDEndTime": currUnixTime},
|
||||||
sq.Eq{"Status": model.StatusDnd},
|
},
|
||||||
sq.Gt{"DNDEndTime": 0},
|
)
|
||||||
sq.LtOrEq{"DNDEndTime": currUnixTime},
|
|
||||||
},
|
err := t.SelectBuilder(&statuses, selectQuery)
|
||||||
).ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "status_tosql")
|
|
||||||
}
|
|
||||||
err = t.Select(&statuses, selectQuery, selectParams...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "updateExpiredStatusesT: failed to get expired dnd statuses")
|
return nil, errors.Wrap(err, "updateExpiredStatusesT: failed to get expired dnd statuses")
|
||||||
}
|
}
|
||||||
updateQuery, args, err := s.getQueryBuilder().
|
updateQuery := s.getQueryBuilder().
|
||||||
Update("Status").
|
Update("Status").
|
||||||
Where(
|
Where(
|
||||||
sq.And{
|
sq.And{
|
||||||
@ -123,14 +115,9 @@ func (s SqlStatusStore) updateExpiredStatuses(t *sqlxTxWrapper) ([]*model.Status
|
|||||||
Set("Status", sq.Expr("PrevStatus")).
|
Set("Status", sq.Expr("PrevStatus")).
|
||||||
Set("PrevStatus", model.StatusDnd).
|
Set("PrevStatus", model.StatusDnd).
|
||||||
Set("DNDEndTime", 0).
|
Set("DNDEndTime", 0).
|
||||||
Set(quoteColumnName(s.DriverName(), "Manual"), false).
|
Set(quoteColumnName(s.DriverName(), "Manual"), false)
|
||||||
ToSql()
|
|
||||||
|
|
||||||
if err != nil {
|
if _, err := t.ExecBuilder(updateQuery); err != nil {
|
||||||
return nil, errors.Wrap(err, "status_tosql")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := t.Exec(updateQuery, args...); err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "updateExpiredStatusesT: failed to update statuses")
|
return nil, errors.Wrapf(err, "updateExpiredStatusesT: failed to update statuses")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +149,7 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error
|
|||||||
return statuses, nil
|
return statuses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
queryString, args, err := s.getQueryBuilder().
|
queryString := s.getQueryBuilder().
|
||||||
Update("Status").
|
Update("Status").
|
||||||
Where(
|
Where(
|
||||||
sq.And{
|
sq.And{
|
||||||
@ -175,30 +162,13 @@ func (s SqlStatusStore) UpdateExpiredDNDStatuses() (_ []*model.Status, err error
|
|||||||
Set("PrevStatus", model.StatusDnd).
|
Set("PrevStatus", model.StatusDnd).
|
||||||
Set("DNDEndTime", 0).
|
Set("DNDEndTime", 0).
|
||||||
Set(quoteColumnName(s.DriverName(), "Manual"), false).
|
Set(quoteColumnName(s.DriverName(), "Manual"), false).
|
||||||
Suffix("RETURNING *").
|
Suffix("RETURNING *")
|
||||||
ToSql()
|
|
||||||
|
|
||||||
if err != nil {
|
statuses := []*model.Status{}
|
||||||
return nil, errors.Wrap(err, "status_tosql")
|
err = s.GetMaster().SelectBuilder(&statuses, queryString)
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := s.GetMaster().Query(queryString, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to find Statuses")
|
return nil, errors.Wrap(err, "failed to find Statuses")
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
statuses := []*model.Status{}
|
|
||||||
for rows.Next() {
|
|
||||||
var status model.Status
|
|
||||||
if err = rows.Scan(&status.UserId, &status.Status, &status.Manual, &status.LastActivityAt,
|
|
||||||
&status.DNDEndTime, &status.PrevStatus); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "unable to scan from rows")
|
|
||||||
}
|
|
||||||
statuses = append(statuses, &status)
|
|
||||||
}
|
|
||||||
if err = rows.Err(); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed while iterating over rows")
|
|
||||||
}
|
|
||||||
|
|
||||||
return statuses, nil
|
return statuses, nil
|
||||||
}
|
}
|
||||||
@ -212,8 +182,13 @@ func (s SqlStatusStore) ResetAll() error {
|
|||||||
|
|
||||||
func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
|
func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
|
||||||
time := model.GetMillis() - (1000 * 60 * 60 * 24)
|
time := model.GetMillis() - (1000 * 60 * 60 * 24)
|
||||||
|
query := s.getQueryBuilder().
|
||||||
|
Select("COUNT(UserId)").
|
||||||
|
From("Status").
|
||||||
|
Where(sq.Gt{"LastActivityAt": time})
|
||||||
|
|
||||||
var count int64
|
var count int64
|
||||||
err := s.GetReplica().Get(&count, "SELECT COUNT(UserId) FROM Status WHERE LastActivityAt > ?", time)
|
err := s.GetReplica().GetBuilder(&count, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return count, errors.Wrap(err, "failed to count active users")
|
return count, errors.Wrap(err, "failed to count active users")
|
||||||
}
|
}
|
||||||
@ -221,7 +196,12 @@ func (s SqlStatusStore) GetTotalActiveUsersCount() (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s SqlStatusStore) UpdateLastActivityAt(userId string, lastActivityAt int64) error {
|
func (s SqlStatusStore) UpdateLastActivityAt(userId string, lastActivityAt int64) error {
|
||||||
if _, err := s.GetMaster().Exec("UPDATE Status SET LastActivityAt = ? WHERE UserId = ?", lastActivityAt, userId); err != nil {
|
builder := s.getQueryBuilder().
|
||||||
|
Update("Status").
|
||||||
|
Set("LastActivityAt", lastActivityAt).
|
||||||
|
Where(sq.Eq{"UserId": userId})
|
||||||
|
|
||||||
|
if _, err := s.GetMaster().ExecBuilder(builder); err != nil {
|
||||||
return errors.Wrapf(err, "failed to update last activity for userId=%s", userId)
|
return errors.Wrapf(err, "failed to update last activity for userId=%s", userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,5 +10,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestStatusStore(t *testing.T) {
|
func TestStatusStore(t *testing.T) {
|
||||||
StoreTest(t, storetest.TestStatusStore)
|
StoreTestWithSqlStore(t, storetest.TestStatusStore)
|
||||||
}
|
}
|
||||||
|
@ -24,11 +24,13 @@ import (
|
|||||||
"github.com/mattermost/mattermost/server/public/shared/timezones"
|
"github.com/mattermost/mattermost/server/public/shared/timezones"
|
||||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||||
"github.com/mattermost/mattermost/server/v8/channels/utils"
|
"github.com/mattermost/mattermost/server/v8/channels/utils"
|
||||||
|
sq "github.com/mattermost/squirrel"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SqlStore interface {
|
type SqlStore interface {
|
||||||
GetMaster() SqlXExecutor
|
GetMaster() SqlXExecutor
|
||||||
DriverName() string
|
DriverName() string
|
||||||
|
GetQueryPlaceholder() sq.PlaceholderFormat
|
||||||
}
|
}
|
||||||
|
|
||||||
type SqlXExecutor interface {
|
type SqlXExecutor interface {
|
||||||
@ -5766,6 +5768,7 @@ func (s ByChannelDisplayName) Len() int { return len(s) }
|
|||||||
func (s ByChannelDisplayName) Swap(i, j int) {
|
func (s ByChannelDisplayName) Swap(i, j int) {
|
||||||
s[i], s[j] = s[j], s[i]
|
s[i], s[j] = s[j], s[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ByChannelDisplayName) Less(i, j int) bool {
|
func (s ByChannelDisplayName) Less(i, j int) bool {
|
||||||
if s[i].DisplayName != s[j].DisplayName {
|
if s[i].DisplayName != s[j].DisplayName {
|
||||||
return s[i].DisplayName < s[j].DisplayName
|
return s[i].DisplayName < s[j].DisplayName
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
sq "github.com/mattermost/squirrel"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/mattermost/mattermost/server/public/model"
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
@ -14,13 +16,15 @@ import (
|
|||||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
func TestStatusStore(t *testing.T, rctx request.CTX, ss store.Store, s SqlStore) {
|
||||||
t.Run("", func(t *testing.T) { testStatusStore(t, rctx, ss) })
|
t.Run("", func(t *testing.T) { testStatusStore(t, rctx, ss) })
|
||||||
t.Run("ActiveUserCount", func(t *testing.T) { testActiveUserCount(t, rctx, ss) })
|
t.Run("ActiveUserCount", func(t *testing.T) { testActiveUserCount(t, rctx, ss) })
|
||||||
t.Run("UpdateExpiredDNDStatuses", func(t *testing.T) { testUpdateExpiredDNDStatuses(t, rctx, ss) })
|
t.Run("UpdateExpiredDNDStatuses", func(t *testing.T) { testUpdateExpiredDNDStatuses(t, rctx, ss) })
|
||||||
|
t.Run("Get", func(t *testing.T) { testStatusGet(t, rctx, ss, s) })
|
||||||
|
t.Run("GetByIds", func(t *testing.T) { testStatusGetByIds(t, rctx, ss, s) })
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
func testStatusStore(t *testing.T, _ request.CTX, ss store.Store) {
|
||||||
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: 0, ActiveChannel: ""}
|
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: 0, ActiveChannel: ""}
|
||||||
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
||||||
|
|
||||||
@ -51,12 +55,19 @@ func testStatusStore(t *testing.T, rctx request.CTX, ss store.Store) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testActiveUserCount(t *testing.T, rctx request.CTX, ss store.Store) {
|
func testActiveUserCount(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||||
status := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
|
status1 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
|
||||||
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||||
|
|
||||||
count, err := ss.Status().GetTotalActiveUsersCount()
|
count1, err := ss.Status().GetTotalActiveUsersCount()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, count > 0, "expected count > 0, got %d", count)
|
assert.Greater(t, count1, int64(0))
|
||||||
|
|
||||||
|
status2 := &model.Status{UserId: model.NewId(), Status: model.StatusOnline, Manual: false, LastActivityAt: model.GetMillis(), ActiveChannel: ""}
|
||||||
|
require.NoError(t, ss.Status().SaveOrUpdate(status2))
|
||||||
|
|
||||||
|
count2, err := ss.Status().GetTotalActiveUsersCount()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, count1+1, count2)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ByUserId []*model.Status
|
type ByUserId []*model.Status
|
||||||
@ -68,8 +79,15 @@ func (s ByUserId) Less(i, j int) bool { return s[i].UserId < s[j].UserId }
|
|||||||
func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store) {
|
func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||||
userID := NewTestID()
|
userID := NewTestID()
|
||||||
|
|
||||||
status := &model.Status{UserId: userID, Status: model.StatusDnd, Manual: true,
|
status := &model.Status{
|
||||||
DNDEndTime: time.Now().Add(5 * time.Second).Unix(), PrevStatus: model.StatusOnline}
|
UserId: userID,
|
||||||
|
Status: model.StatusDnd,
|
||||||
|
Manual: true,
|
||||||
|
LastActivityAt: time.Now().Unix(),
|
||||||
|
ActiveChannel: "channel-id",
|
||||||
|
DNDEndTime: time.Now().Add(5 * time.Second).Unix(),
|
||||||
|
PrevStatus: model.StatusOnline,
|
||||||
|
}
|
||||||
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
require.NoError(t, ss.Status().SaveOrUpdate(status))
|
||||||
|
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
@ -87,9 +105,181 @@ func testUpdateExpiredDNDStatuses(t *testing.T, rctx request.CTX, ss store.Store
|
|||||||
require.Len(t, statuses, 1)
|
require.Len(t, statuses, 1)
|
||||||
|
|
||||||
updatedStatus := *statuses[0]
|
updatedStatus := *statuses[0]
|
||||||
require.Equal(t, updatedStatus.UserId, userID)
|
assert.Equal(t, updatedStatus.UserId, userID)
|
||||||
require.Equal(t, updatedStatus.Status, model.StatusOnline)
|
assert.Equal(t, updatedStatus.Status, model.StatusOnline)
|
||||||
require.Equal(t, updatedStatus.DNDEndTime, int64(0))
|
assert.Equal(t, updatedStatus.Manual, false)
|
||||||
require.Equal(t, updatedStatus.PrevStatus, model.StatusDnd)
|
assert.Equal(t, updatedStatus.LastActivityAt, updatedStatus.LastActivityAt)
|
||||||
require.Equal(t, updatedStatus.Manual, false)
|
assert.Empty(t, updatedStatus.ActiveChannel)
|
||||||
|
assert.Equal(t, updatedStatus.DNDEndTime, int64(0))
|
||||||
|
assert.Equal(t, updatedStatus.PrevStatus, model.StatusDnd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertNullStatus(t *testing.T, ss store.Store, s SqlStore) string {
|
||||||
|
userId := model.NewId()
|
||||||
|
db := ss.GetInternalMasterDB()
|
||||||
|
|
||||||
|
// Insert status with explicit NULL values
|
||||||
|
builder := sq.StatementBuilder.PlaceholderFormat(s.GetQueryPlaceholder()).
|
||||||
|
Insert("Status").
|
||||||
|
Columns("UserId", "Status", quoteColumnName(s.DriverName(), "Manual"), "LastActivityAt", "DNDEndTime", "PrevStatus").
|
||||||
|
Values(userId, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
query, args, err := builder.ToSql()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = db.Exec(query, args...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return userId
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStatusGet(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) {
|
||||||
|
t.Run("null columns", func(t *testing.T) {
|
||||||
|
userId := insertNullStatus(t, ss, s)
|
||||||
|
|
||||||
|
received, err := ss.Status().Get(userId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, userId, received.UserId)
|
||||||
|
assert.Empty(t, received.Status)
|
||||||
|
assert.False(t, received.Manual)
|
||||||
|
assert.Equal(t, int64(0), received.LastActivityAt)
|
||||||
|
assert.Empty(t, received.ActiveChannel)
|
||||||
|
assert.Equal(t, int64(0), received.DNDEndTime)
|
||||||
|
assert.Empty(t, received.PrevStatus)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("status1", func(t *testing.T) {
|
||||||
|
status1 := &model.Status{
|
||||||
|
UserId: model.NewId(),
|
||||||
|
Status: model.StatusDnd,
|
||||||
|
Manual: true,
|
||||||
|
LastActivityAt: 1234,
|
||||||
|
ActiveChannel: "channel-id",
|
||||||
|
DNDEndTime: model.GetMillis(),
|
||||||
|
PrevStatus: model.StatusOnline,
|
||||||
|
}
|
||||||
|
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||||
|
|
||||||
|
received, err := ss.Status().Get(status1.UserId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, status1.UserId, received.UserId)
|
||||||
|
assert.Equal(t, status1.Status, received.Status)
|
||||||
|
assert.Equal(t, status1.Manual, received.Manual)
|
||||||
|
assert.Equal(t, status1.LastActivityAt, received.LastActivityAt)
|
||||||
|
assert.Empty(t, received.ActiveChannel)
|
||||||
|
assert.Equal(t, status1.DNDEndTime, received.DNDEndTime)
|
||||||
|
assert.Equal(t, status1.PrevStatus, received.PrevStatus)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("status2", func(t *testing.T) {
|
||||||
|
status2 := &model.Status{
|
||||||
|
UserId: model.NewId(),
|
||||||
|
Status: model.StatusOffline,
|
||||||
|
Manual: false,
|
||||||
|
LastActivityAt: 12345,
|
||||||
|
ActiveChannel: "channel-id2",
|
||||||
|
DNDEndTime: model.GetMillis(),
|
||||||
|
PrevStatus: model.StatusAway,
|
||||||
|
}
|
||||||
|
require.NoError(t, ss.Status().SaveOrUpdate(status2))
|
||||||
|
|
||||||
|
received, err := ss.Status().Get(status2.UserId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, status2.UserId, received.UserId)
|
||||||
|
assert.Equal(t, status2.Status, received.Status)
|
||||||
|
assert.Equal(t, status2.Manual, received.Manual)
|
||||||
|
assert.Equal(t, status2.LastActivityAt, received.LastActivityAt)
|
||||||
|
assert.Empty(t, received.ActiveChannel)
|
||||||
|
assert.Equal(t, status2.DNDEndTime, received.DNDEndTime)
|
||||||
|
assert.Equal(t, status2.PrevStatus, received.PrevStatus)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStatusGetByIds(t *testing.T, _ request.CTX, ss store.Store, s SqlStore) {
|
||||||
|
t.Run("null columns, single user", func(t *testing.T) {
|
||||||
|
userId := insertNullStatus(t, ss, s)
|
||||||
|
|
||||||
|
received, err := ss.Status().GetByIds([]string{userId})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, received, 1)
|
||||||
|
assert.Equal(t, userId, received[0].UserId)
|
||||||
|
assert.Empty(t, received[0].Status)
|
||||||
|
assert.False(t, received[0].Manual)
|
||||||
|
assert.Equal(t, int64(0), received[0].LastActivityAt)
|
||||||
|
assert.Empty(t, received[0].ActiveChannel)
|
||||||
|
assert.Equal(t, int64(0), received[0].DNDEndTime)
|
||||||
|
assert.Empty(t, received[0].PrevStatus)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single user", func(t *testing.T) {
|
||||||
|
status1 := &model.Status{
|
||||||
|
UserId: model.NewId(),
|
||||||
|
Status: model.StatusDnd,
|
||||||
|
Manual: true,
|
||||||
|
LastActivityAt: 1234,
|
||||||
|
ActiveChannel: "channel-id",
|
||||||
|
DNDEndTime: model.GetMillis(),
|
||||||
|
PrevStatus: model.StatusOnline,
|
||||||
|
}
|
||||||
|
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||||
|
|
||||||
|
received, err := ss.Status().GetByIds([]string{status1.UserId})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, received, 1)
|
||||||
|
assert.Equal(t, status1.UserId, received[0].UserId)
|
||||||
|
assert.Equal(t, status1.Status, received[0].Status)
|
||||||
|
assert.Equal(t, status1.Manual, received[0].Manual)
|
||||||
|
assert.Equal(t, status1.LastActivityAt, received[0].LastActivityAt)
|
||||||
|
assert.Empty(t, received[0].ActiveChannel)
|
||||||
|
assert.Equal(t, status1.DNDEndTime, received[0].DNDEndTime)
|
||||||
|
assert.Equal(t, status1.PrevStatus, received[0].PrevStatus)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple users", func(t *testing.T) {
|
||||||
|
status1 := &model.Status{
|
||||||
|
UserId: model.NewId(),
|
||||||
|
Status: model.StatusDnd,
|
||||||
|
Manual: true,
|
||||||
|
LastActivityAt: 1234,
|
||||||
|
ActiveChannel: "channel-id",
|
||||||
|
DNDEndTime: model.GetMillis(),
|
||||||
|
PrevStatus: model.StatusOnline,
|
||||||
|
}
|
||||||
|
require.NoError(t, ss.Status().SaveOrUpdate(status1))
|
||||||
|
|
||||||
|
status2 := &model.Status{
|
||||||
|
UserId: model.NewId(),
|
||||||
|
Status: model.StatusOffline,
|
||||||
|
Manual: false,
|
||||||
|
LastActivityAt: 12345,
|
||||||
|
ActiveChannel: "channel-id2",
|
||||||
|
DNDEndTime: model.GetMillis(),
|
||||||
|
PrevStatus: model.StatusAway,
|
||||||
|
}
|
||||||
|
require.NoError(t, ss.Status().SaveOrUpdate(status2))
|
||||||
|
|
||||||
|
received, err := ss.Status().GetByIds([]string{status1.UserId, status2.UserId})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, received, 2)
|
||||||
|
|
||||||
|
for _, status := range received {
|
||||||
|
if status.UserId == status1.UserId {
|
||||||
|
assert.Equal(t, status1.UserId, status.UserId)
|
||||||
|
assert.Equal(t, status1.Status, status.Status)
|
||||||
|
assert.Equal(t, status1.Manual, status.Manual)
|
||||||
|
assert.Equal(t, status1.LastActivityAt, status.LastActivityAt)
|
||||||
|
assert.Empty(t, status.ActiveChannel)
|
||||||
|
assert.Equal(t, status1.DNDEndTime, status.DNDEndTime)
|
||||||
|
assert.Equal(t, status1.PrevStatus, status.PrevStatus)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, status2.UserId, status.UserId)
|
||||||
|
assert.Equal(t, status2.Status, status.Status)
|
||||||
|
assert.Equal(t, status2.Manual, status.Manual)
|
||||||
|
assert.Equal(t, status2.LastActivityAt, status.LastActivityAt)
|
||||||
|
assert.Empty(t, status.ActiveChannel)
|
||||||
|
assert.Equal(t, status2.DNDEndTime, status.DNDEndTime)
|
||||||
|
assert.Equal(t, status2.PrevStatus, status.PrevStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
package storetest
|
package storetest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/mattermost/mattermost/server/public/model"
|
"github.com/mattermost/mattermost/server/public/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,3 +21,16 @@ func NewTestID() string {
|
|||||||
|
|
||||||
return string(newID)
|
return string(newID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Adds backtiks to the column name for MySQL, this is required if
|
||||||
|
// the column name is a reserved keyword.
|
||||||
|
//
|
||||||
|
// `ColumnName` - MySQL
|
||||||
|
// ColumnName - Postgres
|
||||||
|
func quoteColumnName(driver string, columnName string) string {
|
||||||
|
if driver == model.DatabaseDriverMysql {
|
||||||
|
return fmt.Sprintf("`%s`", columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user