mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
333 lines
11 KiB
Go
333 lines
11 KiB
Go
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
|
|
// See LICENSE.txt for license information.
|
|
|
|
package storetest
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/mattermost/mattermost-server/v5/model"
|
|
"github.com/mattermost/mattermost-server/v5/store"
|
|
)
|
|
|
|
func TestThreadStore(t *testing.T, ss store.Store, s SqlStore) {
|
|
t.Run("ThreadStorePopulation", func(t *testing.T) { testThreadStorePopulation(t, ss) })
|
|
}
|
|
|
|
func testThreadStorePopulation(t *testing.T, ss store.Store) {
|
|
makeSomePosts := func() []*model.Post {
|
|
|
|
u1 := model.User{
|
|
Email: MakeEmail(),
|
|
Username: model.NewId(),
|
|
}
|
|
|
|
u, err := ss.User().Save(&u1)
|
|
require.NoError(t, err)
|
|
|
|
c, err2 := ss.Channel().Save(&model.Channel{
|
|
DisplayName: model.NewId(),
|
|
Type: model.CHANNEL_OPEN,
|
|
Name: model.NewId(),
|
|
}, 999)
|
|
require.NoError(t, err2)
|
|
|
|
_, err44 := ss.Channel().SaveMember(&model.ChannelMember{
|
|
ChannelId: c.Id,
|
|
UserId: u1.Id,
|
|
NotifyProps: model.GetDefaultChannelNotifyProps(),
|
|
MsgCount: 0,
|
|
})
|
|
require.NoError(t, err44)
|
|
o := model.Post{}
|
|
o.ChannelId = c.Id
|
|
o.UserId = u.Id
|
|
o.Message = "zz" + model.NewId() + "b"
|
|
|
|
otmp, err3 := ss.Post().Save(&o)
|
|
require.NoError(t, err3)
|
|
o2 := model.Post{}
|
|
o2.ChannelId = c.Id
|
|
o2.UserId = model.NewId()
|
|
o2.RootId = otmp.Id
|
|
o2.Message = "zz" + model.NewId() + "b"
|
|
|
|
o3 := model.Post{}
|
|
o3.ChannelId = c.Id
|
|
o3.UserId = u.Id
|
|
o3.RootId = otmp.Id
|
|
o3.Message = "zz" + model.NewId() + "b"
|
|
|
|
o4 := model.Post{}
|
|
o4.ChannelId = c.Id
|
|
o4.UserId = model.NewId()
|
|
o4.Message = "zz" + model.NewId() + "b"
|
|
|
|
newPosts, errIdx, err3 := ss.Post().SaveMultiple([]*model.Post{&o2, &o3, &o4})
|
|
|
|
olist, _ := ss.Post().Get(context.Background(), otmp.Id, true, false, false, "")
|
|
o1 := olist.Posts[olist.Order[0]]
|
|
|
|
newPosts = append([]*model.Post{o1}, newPosts...)
|
|
require.NoError(t, err3, "couldn't save item")
|
|
require.Equal(t, -1, errIdx)
|
|
require.Len(t, newPosts, 4)
|
|
require.Equal(t, int64(2), newPosts[0].ReplyCount)
|
|
require.Equal(t, int64(2), newPosts[1].ReplyCount)
|
|
require.Equal(t, int64(2), newPosts[2].ReplyCount)
|
|
require.Equal(t, int64(0), newPosts[3].ReplyCount)
|
|
|
|
return newPosts
|
|
}
|
|
t.Run("Save replies creates a thread", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
thread, err := ss.Thread().Get(newPosts[0].Id)
|
|
require.NoError(t, err, "couldn't get thread")
|
|
require.NotNil(t, thread)
|
|
require.Equal(t, int64(2), thread.ReplyCount)
|
|
require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, thread.Participants)
|
|
|
|
o5 := model.Post{}
|
|
o5.ChannelId = model.NewId()
|
|
o5.UserId = model.NewId()
|
|
o5.RootId = newPosts[0].Id
|
|
o5.Message = "zz" + model.NewId() + "b"
|
|
|
|
_, _, err = ss.Post().SaveMultiple([]*model.Post{&o5})
|
|
require.NoError(t, err, "couldn't save item")
|
|
|
|
thread, err = ss.Thread().Get(newPosts[0].Id)
|
|
require.NoError(t, err, "couldn't get thread")
|
|
require.NotNil(t, thread)
|
|
require.Equal(t, int64(3), thread.ReplyCount)
|
|
require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId, o5.UserId}, thread.Participants)
|
|
})
|
|
|
|
t.Run("Delete a reply updates count on a thread", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
thread, err := ss.Thread().Get(newPosts[0].Id)
|
|
require.NoError(t, err, "couldn't get thread")
|
|
require.NotNil(t, thread)
|
|
require.Equal(t, int64(2), thread.ReplyCount)
|
|
require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, thread.Participants)
|
|
|
|
err = ss.Post().Delete(newPosts[1].Id, 1234, model.NewId())
|
|
require.NoError(t, err, "couldn't delete post")
|
|
|
|
thread, err = ss.Thread().Get(newPosts[0].Id)
|
|
require.NoError(t, err, "couldn't get thread")
|
|
require.NotNil(t, thread)
|
|
require.Equal(t, int64(1), thread.ReplyCount)
|
|
require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, thread.Participants)
|
|
})
|
|
|
|
t.Run("Update reply should update the UpdateAt of the thread", func(t *testing.T) {
|
|
rootPost := model.Post{}
|
|
rootPost.RootId = model.NewId()
|
|
rootPost.ChannelId = model.NewId()
|
|
rootPost.UserId = model.NewId()
|
|
rootPost.Message = "zz" + model.NewId() + "b"
|
|
|
|
replyPost := model.Post{}
|
|
replyPost.ChannelId = rootPost.ChannelId
|
|
replyPost.UserId = model.NewId()
|
|
replyPost.Message = "zz" + model.NewId() + "b"
|
|
replyPost.RootId = rootPost.RootId
|
|
|
|
newPosts, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost, &replyPost})
|
|
require.NoError(t, err)
|
|
|
|
thread1, err := ss.Thread().Get(newPosts[0].RootId)
|
|
require.NoError(t, err)
|
|
|
|
rrootPost, err := ss.Post().GetSingle(rootPost.Id, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, rrootPost.UpdateAt, rootPost.UpdateAt)
|
|
|
|
replyPost2 := model.Post{}
|
|
replyPost2.ChannelId = rootPost.ChannelId
|
|
replyPost2.UserId = model.NewId()
|
|
replyPost2.Message = "zz" + model.NewId() + "b"
|
|
replyPost2.RootId = rootPost.Id
|
|
|
|
replyPost3 := model.Post{}
|
|
replyPost3.ChannelId = rootPost.ChannelId
|
|
replyPost3.UserId = model.NewId()
|
|
replyPost3.Message = "zz" + model.NewId() + "b"
|
|
replyPost3.RootId = rootPost.Id
|
|
|
|
_, _, err = ss.Post().SaveMultiple([]*model.Post{&replyPost2, &replyPost3})
|
|
require.NoError(t, err)
|
|
|
|
rrootPost2, err := ss.Post().GetSingle(rootPost.Id, false)
|
|
require.NoError(t, err)
|
|
require.Greater(t, rrootPost2.UpdateAt, rrootPost.UpdateAt)
|
|
|
|
thread2, err := ss.Thread().Get(rootPost.Id)
|
|
require.NoError(t, err)
|
|
require.Greater(t, thread2.LastReplyAt, thread1.LastReplyAt)
|
|
})
|
|
|
|
t.Run("Deleting reply should update the thread", func(t *testing.T) {
|
|
rootPost := model.Post{}
|
|
rootPost.RootId = model.NewId()
|
|
rootPost.ChannelId = model.NewId()
|
|
rootPost.UserId = model.NewId()
|
|
rootPost.Message = "zz" + model.NewId() + "b"
|
|
|
|
replyPost := model.Post{}
|
|
replyPost.ChannelId = rootPost.ChannelId
|
|
replyPost.UserId = model.NewId()
|
|
replyPost.Message = "zz" + model.NewId() + "b"
|
|
replyPost.RootId = rootPost.RootId
|
|
|
|
newPosts, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost, &replyPost})
|
|
require.NoError(t, err)
|
|
|
|
thread1, err := ss.Thread().Get(newPosts[0].RootId)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, thread1.ReplyCount, 2)
|
|
require.Len(t, thread1.Participants, 2)
|
|
|
|
err = ss.Post().Delete(replyPost.Id, 123, model.NewId())
|
|
require.NoError(t, err)
|
|
|
|
thread2, err := ss.Thread().Get(rootPost.RootId)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, thread2.ReplyCount, 1)
|
|
require.Len(t, thread2.Participants, 2)
|
|
})
|
|
|
|
t.Run("Deleting root post should delete the thread", func(t *testing.T) {
|
|
rootPost := model.Post{}
|
|
rootPost.ChannelId = model.NewId()
|
|
rootPost.UserId = model.NewId()
|
|
rootPost.Message = "zz" + model.NewId() + "b"
|
|
|
|
newPosts1, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost})
|
|
require.NoError(t, err)
|
|
|
|
replyPost := model.Post{}
|
|
replyPost.ChannelId = rootPost.ChannelId
|
|
replyPost.UserId = model.NewId()
|
|
replyPost.Message = "zz" + model.NewId() + "b"
|
|
replyPost.RootId = newPosts1[0].Id
|
|
|
|
_, _, err = ss.Post().SaveMultiple([]*model.Post{&replyPost})
|
|
require.NoError(t, err)
|
|
|
|
thread1, err := ss.Thread().Get(newPosts1[0].Id)
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, thread1.ReplyCount, 1)
|
|
require.Len(t, thread1.Participants, 2)
|
|
|
|
err = ss.Post().PermanentDeleteByUser(rootPost.UserId)
|
|
require.NoError(t, err)
|
|
|
|
thread2, _ := ss.Thread().Get(rootPost.Id)
|
|
require.Nil(t, thread2)
|
|
})
|
|
|
|
t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAtPost", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false, false)
|
|
require.NoError(t, e)
|
|
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err1)
|
|
m.LastUpdated -= 1000
|
|
_, err := ss.Thread().UpdateMembership(m)
|
|
require.NoError(t, err)
|
|
|
|
_, err = ss.Channel().UpdateLastViewedAtPost(newPosts[0], newPosts[0].UserId, 0, 0, true)
|
|
require.NoError(t, err)
|
|
|
|
assert.Eventually(t, func() bool {
|
|
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err2)
|
|
return m2.LastUpdated > m.LastUpdated
|
|
}, time.Second, 10*time.Millisecond)
|
|
})
|
|
|
|
t.Run("Thread last updated is changed when channel is updated after IncrementMentionCount", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
|
|
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false, false)
|
|
require.NoError(t, e)
|
|
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err1)
|
|
m.LastUpdated -= 1000
|
|
_, err := ss.Thread().UpdateMembership(m)
|
|
require.NoError(t, err)
|
|
|
|
err = ss.Channel().IncrementMentionCount(newPosts[0].ChannelId, newPosts[0].UserId, true, false)
|
|
require.NoError(t, err)
|
|
|
|
assert.Eventually(t, func() bool {
|
|
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err2)
|
|
return m2.LastUpdated > m.LastUpdated
|
|
}, time.Second, 10*time.Millisecond)
|
|
})
|
|
|
|
t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAt", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false, false)
|
|
require.NoError(t, e)
|
|
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err1)
|
|
m.LastUpdated -= 1000
|
|
_, err := ss.Thread().UpdateMembership(m)
|
|
require.NoError(t, err)
|
|
|
|
_, err = ss.Channel().UpdateLastViewedAt([]string{newPosts[0].ChannelId}, newPosts[0].UserId, true)
|
|
require.NoError(t, err)
|
|
|
|
assert.Eventually(t, func() bool {
|
|
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err2)
|
|
return m2.LastUpdated > m.LastUpdated
|
|
}, time.Second, 10*time.Millisecond)
|
|
})
|
|
|
|
t.Run("Thread membership 'viewed' timestamp is updated properly", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
|
|
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false, false)
|
|
require.NoError(t, e)
|
|
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err1)
|
|
require.Equal(t, int64(0), m.LastViewed)
|
|
|
|
_, e = ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, true, false)
|
|
require.NoError(t, e)
|
|
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err2)
|
|
require.Greater(t, m2.LastViewed, int64(0))
|
|
})
|
|
|
|
t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAtPost for mark unread", func(t *testing.T) {
|
|
newPosts := makeSomePosts()
|
|
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false, false)
|
|
require.NoError(t, e)
|
|
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err1)
|
|
m.LastUpdated += 1000
|
|
_, err := ss.Thread().UpdateMembership(m)
|
|
require.NoError(t, err)
|
|
|
|
_, err = ss.Channel().UpdateLastViewedAtPost(newPosts[0], newPosts[0].UserId, 0, 0, true)
|
|
require.NoError(t, err)
|
|
|
|
assert.Eventually(t, func() bool {
|
|
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
|
|
require.NoError(t, err2)
|
|
return m2.LastUpdated < m.LastUpdated
|
|
}, time.Second, 10*time.Millisecond)
|
|
})
|
|
}
|