MM-34609 Mark-as-unread on a post in a thread should cause auto-follow (#17343)

Co-authored-by: Mattermod <mattermod@users.noreply.github.com>
This commit is contained in:
Eli Yukelzon
2021-04-16 10:26:08 +03:00
committed by GitHub
parent 518e0ed371
commit f90209c8a3
11 changed files with 78 additions and 35 deletions

View File

@@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"sort"
"strings"
@@ -2544,3 +2545,30 @@ func TestSetChannelUnread(t *testing.T) {
checkHTTPStatus(t, response, http.StatusUnauthorized, true)
})
}
func TestMarkUnreadCausesAutofollow(t *testing.T) {
os.Setenv("MM_FEATUREFLAGS_COLLAPSEDTHREADS", "true")
defer os.Unsetenv("MM_FEATUREFLAGS_COLLAPSEDTHREADS")
th := Setup(t).InitBasic()
defer th.TearDown()
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.ThreadAutoFollow = true
*cfg.ServiceSettings.CollapsedThreads = model.COLLAPSED_THREADS_DEFAULT_ON
})
rootPost, appErr := th.App.CreatePost(&model.Post{UserId: th.BasicUser2.Id, CreateAt: model.GetMillis(), ChannelId: th.BasicChannel.Id, Message: "hi"}, th.BasicChannel, false, false)
require.Nil(t, appErr)
replyPost, appErr := th.App.CreatePost(&model.Post{RootId: rootPost.Id, UserId: th.BasicUser2.Id, CreateAt: model.GetMillis(), ChannelId: th.BasicChannel.Id, Message: "hi"}, th.BasicChannel, false, false)
require.Nil(t, appErr)
threads, appErr := th.App.GetThreadsForUser(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{})
require.Nil(t, appErr)
require.Zero(t, threads.Total)
_, appErr = th.App.MarkChannelAsUnreadFromPost(replyPost.Id, th.BasicUser.Id)
require.Nil(t, appErr)
threads, appErr = th.App.GetThreadsForUser(th.BasicUser.Id, th.BasicTeam.Id, model.GetUserThreadsOpts{})
require.Nil(t, appErr)
require.NotZero(t, threads.Total)
}

View File

@@ -2385,6 +2385,9 @@ func (a *App) MarkChannelAsUnreadFromPost(postID string, userID string) (*model.
}
threadMembership, _ := a.Srv().Store.Thread().GetMembershipForUser(user.Id, threadId)
if threadMembership == nil {
threadMembership, _ = a.Srv().Store.Thread().MaintainMembership(user.Id, threadId, true, true, true, true)
}
if threadMembership != nil && threadMembership.Following {
channel, nErr := a.Srv().Store.Channel().Get(post.ChannelId, true)
if nErr != nil {

View File

@@ -185,7 +185,7 @@ func (a *App) SendNotifications(post *model.Post, team *model.Team, channel *mod
go func(userID string) {
defer close(mac)
_, incrementMentions := mentions.Mentions[userID]
err := a.Srv().Store.Thread().MaintainMembership(userID, post.RootId, true, incrementMentions, *a.Config().ServiceSettings.ThreadAutoFollow, userID == post.UserId)
_, err := a.Srv().Store.Thread().MaintainMembership(userID, post.RootId, true, incrementMentions, *a.Config().ServiceSettings.ThreadAutoFollow, userID == post.UserId)
if err != nil {
mac <- model.NewAppError("SendNotifications", "app.channel.autofollow.app_error", nil, err.Error(), http.StatusInternalServerError)
return

View File

@@ -2424,7 +2424,7 @@ func (a *App) UpdateThreadsReadForUser(userID, teamID string) *model.AppError {
}
func (a *App) UpdateThreadFollowForUser(userID, teamID, threadID string, state bool) *model.AppError {
err := a.Srv().Store.Thread().MaintainMembership(userID, threadID, state, false, true, false)
_, err := a.Srv().Store.Thread().MaintainMembership(userID, threadID, state, false, true, false)
if err != nil {
return model.NewAppError("UpdateThreadFollowForUser", "app.user.update_thread_follow_for_user.app_error", nil, err.Error(), http.StatusInternalServerError)
}

View File

@@ -8550,7 +8550,7 @@ func (s *OpenTracingLayerThreadStore) GetThreadsForUser(userId string, teamID st
return result, err
}
func (s *OpenTracingLayerThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) error {
func (s *OpenTracingLayerThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) (*model.ThreadMembership, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ThreadStore.MaintainMembership")
s.Root.Store.SetContext(newCtx)
@@ -8559,13 +8559,13 @@ func (s *OpenTracingLayerThreadStore) MaintainMembership(userID string, postID s
}()
defer span.Finish()
err := s.ThreadStore.MaintainMembership(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
result, err := s.ThreadStore.MaintainMembership(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
}
return err
return result, err
}
func (s *OpenTracingLayerThreadStore) MarkAllAsRead(userID string, teamID string) error {

View File

@@ -9298,21 +9298,21 @@ func (s *RetryLayerThreadStore) GetThreadsForUser(userId string, teamID string,
}
func (s *RetryLayerThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) error {
func (s *RetryLayerThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) (*model.ThreadMembership, error) {
tries := 0
for {
err := s.ThreadStore.MaintainMembership(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
result, err := s.ThreadStore.MaintainMembership(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
if err == nil {
return nil
return result, nil
}
if !isRepeatableError(err) {
return err
return result, err
}
tries++
if tries >= 3 {
err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
return err
return result, err
}
}

View File

@@ -472,7 +472,7 @@ func (s *SqlThreadStore) DeleteMembershipForUser(userId string, postId string) e
return nil
}
func (s *SqlThreadStore) MaintainMembership(userId, postId string, following, incrementMentions, updateFollowing, updateViewedTimestamp bool) error {
func (s *SqlThreadStore) MaintainMembership(userId, postId string, following, incrementMentions, updateFollowing, updateViewedTimestamp bool) (*model.ThreadMembership, error) {
membership, err := s.GetMembershipForUser(userId, postId)
now := utils.MillisFromTime(time.Now())
// if memebership exists, update it if:
@@ -494,19 +494,19 @@ func (s *SqlThreadStore) MaintainMembership(userId, postId string, following, in
}
_, err = s.UpdateMembership(membership)
}
return err
return nil, err
}
var nfErr *store.ErrNotFound
if !errors.As(err, &nfErr) {
return errors.Wrap(err, "failed to get thread membership")
return nil, errors.Wrap(err, "failed to get thread membership")
}
mentions := 0
if incrementMentions {
mentions = 1
}
_, err = s.SaveMembership(&model.ThreadMembership{
membership, err = s.SaveMembership(&model.ThreadMembership{
PostId: postId,
UserId: userId,
Following: following,
@@ -515,18 +515,18 @@ func (s *SqlThreadStore) MaintainMembership(userId, postId string, following, in
UnreadMentions: int64(mentions),
})
if err != nil {
return err
return nil, err
}
thread, err := s.Get(postId)
if err != nil {
return err
return nil, err
}
if !thread.Participants.Contains(userId) {
thread.Participants = append(thread.Participants, userId)
_, err = s.Update(thread)
}
return err
return membership, err
}
func (s *SqlThreadStore) CollectThreadsWithNewerReplies(userId string, channelIds []string, timestamp int64) ([]string, error) {

View File

@@ -272,7 +272,7 @@ type ThreadStore interface {
GetMembershipsForUser(userId, teamID string) ([]*model.ThreadMembership, error)
GetMembershipForUser(userId, postID string) (*model.ThreadMembership, error)
DeleteMembershipForUser(userId, postID string) error
MaintainMembership(userID, postID string, following, incrementMentions, updateFollowing, updateViewedTimestamp bool) error
MaintainMembership(userID, postID string, following, incrementMentions, updateFollowing, updateViewedTimestamp bool) (*model.ThreadMembership, error)
CollectThreadsWithNewerReplies(userId string, channelIds []string, timestamp int64) ([]string, error)
UpdateUnreadsByChannel(userId string, changedThreads []string, timestamp int64, updateViewedTimestamp bool) error
}

View File

@@ -204,17 +204,26 @@ func (_m *ThreadStore) GetThreadsForUser(userId string, teamID string, opts mode
}
// MaintainMembership provides a mock function with given fields: userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp
func (_m *ThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) error {
func (_m *ThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) (*model.ThreadMembership, error) {
ret := _m.Called(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
var r0 error
if rf, ok := ret.Get(0).(func(string, string, bool, bool, bool, bool) error); ok {
var r0 *model.ThreadMembership
if rf, ok := ret.Get(0).(func(string, string, bool, bool, bool, bool) *model.ThreadMembership); ok {
r0 = rf(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
} else {
r0 = ret.Error(0)
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.ThreadMembership)
}
}
return r0
var r1 error
if rf, ok := ret.Get(1).(func(string, string, bool, bool, bool, bool) error); ok {
r1 = rf(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MarkAllAsRead provides a mock function with given fields: userID, teamID

View File

@@ -235,8 +235,8 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAtPost", func(t *testing.T) {
newPosts := makeSomePosts()
require.NoError(t, ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false))
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false)
require.NoError(t, e)
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err1)
m.LastUpdated -= 1000
@@ -256,7 +256,8 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
t.Run("Thread last updated is changed when channel is updated after IncrementMentionCount", func(t *testing.T) {
newPosts := makeSomePosts()
require.NoError(t, ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false))
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false)
require.NoError(t, e)
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err1)
m.LastUpdated -= 1000
@@ -275,8 +276,8 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAt", func(t *testing.T) {
newPosts := makeSomePosts()
require.NoError(t, ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false))
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false)
require.NoError(t, e)
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err1)
m.LastUpdated -= 1000
@@ -296,12 +297,14 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
t.Run("Thread membership 'viewed' timestamp is updated properly", func(t *testing.T) {
newPosts := makeSomePosts()
require.NoError(t, ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false))
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false)
require.NoError(t, e)
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err1)
require.Equal(t, int64(0), m.LastViewed)
require.NoError(t, ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, true))
_, e = ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, true)
require.NoError(t, e)
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err2)
require.Greater(t, m2.LastViewed, int64(0))
@@ -309,8 +312,8 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAtPost for mark unread", func(t *testing.T) {
newPosts := makeSomePosts()
require.NoError(t, ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false))
_, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, true, false, true, false)
require.NoError(t, e)
m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err1)
m.LastUpdated += 1000

View File

@@ -7708,10 +7708,10 @@ func (s *TimerLayerThreadStore) GetThreadsForUser(userId string, teamID string,
return result, err
}
func (s *TimerLayerThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) error {
func (s *TimerLayerThreadStore) MaintainMembership(userID string, postID string, following bool, incrementMentions bool, updateFollowing bool, updateViewedTimestamp bool) (*model.ThreadMembership, error) {
start := timemodule.Now()
err := s.ThreadStore.MaintainMembership(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
result, err := s.ThreadStore.MaintainMembership(userID, postID, following, incrementMentions, updateFollowing, updateViewedTimestamp)
elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second)
if s.Root.Metrics != nil {
@@ -7721,7 +7721,7 @@ func (s *TimerLayerThreadStore) MaintainMembership(userID string, postID string,
}
s.Root.Metrics.ObserveStoreMethodDuration("ThreadStore.MaintainMembership", success, elapsed)
}
return err
return result, err
}
func (s *TimerLayerThreadStore) MarkAllAsRead(userID string, teamID string) error {