[MM-32622] Remove app.WaitForChannelMembership() (#17048)

* Remove app.WaitForChannelMembership

* Fix tests

* Fix test

Co-authored-by: Mattermod <mattermod@users.noreply.github.com>
This commit is contained in:
Claudio Costa
2021-03-31 09:40:35 +02:00
committed by GitHub
parent 4ba0c09fc7
commit ee3f986da0
32 changed files with 211 additions and 227 deletions

View File

@@ -1055,7 +1055,7 @@ func (th *TestHelper) cleanupTestFile(info *model.FileInfo) error {
func (th *TestHelper) MakeUserChannelAdmin(user *model.User, channel *model.Channel) {
utils.DisableDebugLogForTest()
if cm, err := th.App.Srv().Store.Channel().GetMember(channel.Id, user.Id); err == nil {
if cm, err := th.App.Srv().Store.Channel().GetMember(context.Background(), channel.Id, user.Id); err == nil {
cm.SchemeAdmin = true
if _, err = th.App.Srv().Store.Channel().UpdateMember(cm); err != nil {
utils.EnableDebugLogForTest()

View File

@@ -4,6 +4,7 @@
package api4
import (
"context"
"encoding/json"
"net/http"
"strconv"
@@ -158,7 +159,7 @@ func updateChannel(c *Context, w http.ResponseWriter, r *http.Request) {
case model.CHANNEL_GROUP, model.CHANNEL_DIRECT:
// Modifying the header is not linked to any specific permission for group/dm channels, so just check for membership.
if _, errGet := c.App.GetChannelMember(channel.Id, c.App.Session().UserId); errGet != nil {
if _, errGet := c.App.GetChannelMember(context.Background(), channel.Id, c.App.Session().UserId); errGet != nil {
c.Err = model.NewAppError("updateChannel", "api.channel.patch_update_channel.forbidden.app_error", nil, "", http.StatusForbidden)
return
}
@@ -371,7 +372,7 @@ func patchChannel(c *Context, w http.ResponseWriter, r *http.Request) {
case model.CHANNEL_GROUP, model.CHANNEL_DIRECT:
// Modifying the header is not linked to any specific permission for group/dm channels, so just check for membership.
if _, err = c.App.GetChannelMember(c.Params.ChannelId, c.App.Session().UserId); err != nil {
if _, err = c.App.GetChannelMember(context.Background(), c.Params.ChannelId, c.App.Session().UserId); err != nil {
c.Err = model.NewAppError("patchChannel", "api.channel.patch_update_channel.forbidden.app_error", nil, "", http.StatusForbidden)
return
}
@@ -1249,7 +1250,7 @@ func getChannelMember(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
member, err := c.App.GetChannelMember(c.Params.ChannelId, c.Params.UserId)
member, err := c.App.GetChannelMember(app.WithMaster(context.Background()), c.Params.ChannelId, c.Params.UserId)
if err != nil {
c.Err = err
return
@@ -1480,7 +1481,7 @@ func addChannelMember(c *Context, w http.ResponseWriter, r *http.Request) {
}
isNewMembership := false
if _, err = c.App.GetChannelMember(member.ChannelId, member.UserId); err != nil {
if _, err = c.App.GetChannelMember(context.Background(), member.ChannelId, member.UserId); err != nil {
if err.Id == app.MissingChannelMemberError {
isNewMembership = true
} else {
@@ -1736,7 +1737,7 @@ func channelMemberCountsByGroup(c *Context, w http.ResponseWriter, r *http.Reque
includeTimezones := r.URL.Query().Get("include_timezones") == "true"
channelMemberCounts, err := c.App.GetMemberCountsByGroup(c.Params.ChannelId, includeTimezones)
channelMemberCounts, err := c.App.GetMemberCountsByGroup(app.WithMaster(context.Background()), c.Params.ChannelId, includeTimezones)
if err != nil {
c.Err = err
return

View File

@@ -4,6 +4,7 @@
package api4
import (
"context"
"fmt"
"net/http"
"sort"
@@ -2491,7 +2492,7 @@ func TestUpdateChannelNotifyProps(t *testing.T) {
CheckNoError(t, resp)
require.True(t, pass, "should have passed")
member, err := th.App.GetChannelMember(th.BasicChannel.Id, th.BasicUser.Id)
member, err := th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.Nil(t, err)
require.Equal(t, model.CHANNEL_NOTIFY_MENTION, member.NotifyProps[model.DESKTOP_NOTIFY_PROP], "bad update")
require.Equal(t, model.CHANNEL_MARK_UNREAD_MENTION, member.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP], "bad update")

View File

@@ -4,6 +4,7 @@
package api4
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -1732,7 +1733,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) {
// Set channel member's last viewed to 0.
// All returned posts are latest posts as if all previous posts were already read by the user.
channelMember, err := th.App.Srv().Store.Channel().GetMember(channelId, userId)
channelMember, err := th.App.Srv().Store.Channel().GetMember(context.Background(), channelId, userId)
require.NoError(t, err)
channelMember.LastViewedAt = 0
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -1753,7 +1754,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) {
postIdNames[systemPost1.Id] = "system post 1"
// Set channel member's last viewed before post1.
channelMember, err = th.App.Srv().Store.Channel().GetMember(channelId, userId)
channelMember, err = th.App.Srv().Store.Channel().GetMember(context.Background(), channelId, userId)
require.NoError(t, err)
channelMember.LastViewedAt = post1.CreateAt - 1
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -1777,7 +1778,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) {
}, posts)
// Set channel member's last viewed before post6.
channelMember, err = th.App.Srv().Store.Channel().GetMember(channelId, userId)
channelMember, err = th.App.Srv().Store.Channel().GetMember(context.Background(), channelId, userId)
require.NoError(t, err)
channelMember.LastViewedAt = post6.CreateAt - 1
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -1804,7 +1805,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) {
}, posts)
// Set channel member's last viewed before post10.
channelMember, err = th.App.Srv().Store.Channel().GetMember(channelId, userId)
channelMember, err = th.App.Srv().Store.Channel().GetMember(context.Background(), channelId, userId)
require.NoError(t, err)
channelMember.LastViewedAt = post10.CreateAt - 1
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -1829,7 +1830,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) {
}, posts)
// Set channel member's last viewed equal to post10.
channelMember, err = th.App.Srv().Store.Channel().GetMember(channelId, userId)
channelMember, err = th.App.Srv().Store.Channel().GetMember(context.Background(), channelId, userId)
require.NoError(t, err)
channelMember.LastViewedAt = post10.CreateAt
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -1869,7 +1870,7 @@ func TestGetPostsForChannelAroundLastUnread(t *testing.T) {
postIdNames[post12.Id] = "post12 (reply to post4)"
postIdNames[post13.Id] = "post13"
channelMember, err = th.App.Srv().Store.Channel().GetMember(channelId, userId)
channelMember, err = th.App.Srv().Store.Channel().GetMember(context.Background(), channelId, userId)
require.NoError(t, err)
channelMember.LastViewedAt = post12.CreateAt - 1
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)

View File

@@ -543,7 +543,7 @@ type AppIface interface {
GetChannelByNameForTeamName(channelName, teamName string, includeDeleted bool) (*model.Channel, *model.AppError)
GetChannelCounts(teamID string, userID string) (*model.ChannelCounts, *model.AppError)
GetChannelGuestCount(channelID string) (int64, *model.AppError)
GetChannelMember(channelID string, userID string) (*model.ChannelMember, *model.AppError)
GetChannelMember(ctx context.Context, channelID string, userID string) (*model.ChannelMember, *model.AppError)
GetChannelMemberCount(channelID string) (int64, *model.AppError)
GetChannelMembersByIds(channelID string, userIDs []string) (*model.ChannelMembers, *model.AppError)
GetChannelMembersForUser(teamID string, userID string) (*model.ChannelMembers, *model.AppError)
@@ -610,7 +610,7 @@ type AppIface interface {
GetLatestTermsOfService() (*model.TermsOfService, *model.AppError)
GetLogs(page, perPage int) ([]string, *model.AppError)
GetLogsSkipSend(page, perPage int) ([]string, *model.AppError)
GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
GetMessageForNotification(post *model.Post, translateFunc i18n.TranslateFunc) string
GetMultipleEmojiByName(names []string) ([]*model.Emoji, *model.AppError)
GetNewUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError)
@@ -1044,6 +1044,5 @@ type AppIface interface {
VerifyEmailFromToken(userSuppliedTokenString string) *model.AppError
VerifyUserEmail(userID, email string) *model.AppError
ViewChannel(view *model.ChannelView, userID string, currentSessionId string) (map[string]int64, *model.AppError)
WaitForChannelMembership(channelID string, userID string)
WriteFile(fr io.Reader, path string) (int64, *model.AppError)
}

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"net/http"
"strings"
@@ -175,7 +176,7 @@ func (a *App) HasPermissionToChannel(askingUserId string, channelID string, perm
return false
}
channelMember, err := a.GetChannelMember(channelID, askingUserId)
channelMember, err := a.GetChannelMember(context.Background(), channelID, askingUserId)
if err == nil {
roles := channelMember.GetRoles()
if a.RolesGrantPermission(roles, permission.Id) {

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/mattermost/mattermost-server/v5/model"
"github.com/mattermost/mattermost-server/v5/plugin"
@@ -341,7 +340,6 @@ func (a *App) GetOrCreateDirectChannel(userID, otherUserID string) (*model.Chann
return nil, err
}
a.WaitForChannelMembership(channel.Id, userID)
a.handleCreationEvent(userID, otherUserID, channel)
return channel, nil
}
@@ -364,7 +362,6 @@ func (a *App) getOrCreateDirectChannelWithUser(user, otherUser *model.User) (*mo
return nil, err
}
a.WaitForChannelMembership(channel.Id, user.Id)
a.handleCreationEvent(user.Id, otherUser.Id, channel)
return channel, nil
}
@@ -466,34 +463,6 @@ func (a *App) createDirectChannelWithUser(user, otherUser *model.User) (*model.C
return channel, nil
}
func (a *App) WaitForChannelMembership(channelID string, userID string) {
if len(a.Config().SqlSettings.DataSourceReplicas) == 0 {
return
}
now := model.GetMillis()
for model.GetMillis()-now < 12000 {
time.Sleep(100 * time.Millisecond)
_, err := a.Srv().Store.Channel().GetMember(channelID, userID)
// If the membership was found then return
if err == nil {
return
}
// If we received an error, but it wasn't a missing channel member then return
var nfErr *store.ErrNotFound
if !errors.As(err, &nfErr) {
return
}
}
mlog.Error("WaitForChannelMembership giving up", mlog.String("channel_id", channelID), mlog.String("user_id", userID))
}
func (a *App) CreateGroupChannel(userIDs []string, creatorId string) (*model.Channel, *model.AppError) {
channel, err := a.createGroupChannel(userIDs)
if err != nil {
@@ -504,10 +473,6 @@ func (a *App) CreateGroupChannel(userIDs []string, creatorId string) (*model.Cha
}
for _, userID := range userIDs {
if userID == creatorId {
a.WaitForChannelMembership(channel.Id, creatorId)
}
a.InvalidateCacheForUser(userID)
}
@@ -1087,7 +1052,7 @@ func buildChannelModerations(channelType string, memberRole *model.Role, guestRo
func (a *App) UpdateChannelMemberRoles(channelID string, userID string, newRoles string) (*model.ChannelMember, *model.AppError) {
var member *model.ChannelMember
var err *model.AppError
if member, err = a.GetChannelMember(channelID, userID); err != nil {
if member, err = a.GetChannelMember(context.Background(), channelID, userID); err != nil {
return nil, err
}
@@ -1144,7 +1109,7 @@ func (a *App) UpdateChannelMemberRoles(channelID string, userID string, newRoles
}
func (a *App) UpdateChannelMemberSchemeRoles(channelID string, userID string, isSchemeGuest bool, isSchemeUser bool, isSchemeAdmin bool) (*model.ChannelMember, *model.AppError) {
member, err := a.GetChannelMember(channelID, userID)
member, err := a.GetChannelMember(context.Background(), channelID, userID)
if err != nil {
return nil, err
}
@@ -1168,7 +1133,7 @@ func (a *App) UpdateChannelMemberSchemeRoles(channelID string, userID string, is
func (a *App) UpdateChannelMemberNotifyProps(data map[string]string, channelID string, userID string) (*model.ChannelMember, *model.AppError) {
var member *model.ChannelMember
var err *model.AppError
if member, err = a.GetChannelMember(channelID, userID); err != nil {
if member, err = a.GetChannelMember(context.Background(), channelID, userID); err != nil {
return nil, err
}
@@ -1334,7 +1299,7 @@ func (a *App) addUserToChannel(user *model.User, channel *model.Channel) (*model
return nil, model.NewAppError("AddUserToChannel", "api.channel.add_user_to_channel.type.app_error", nil, "", http.StatusBadRequest)
}
channelMember, nErr := a.Srv().Store.Channel().GetMember(channel.Id, user.Id)
channelMember, nErr := a.Srv().Store.Channel().GetMember(context.Background(), channel.Id, user.Id)
if nErr != nil {
var nfErr *store.ErrNotFound
if !errors.As(nErr, &nfErr) {
@@ -1375,7 +1340,6 @@ func (a *App) addUserToChannel(user *model.User, channel *model.Channel) (*model
if nErr != nil {
return nil, model.NewAppError("AddUserToChannel", "api.channel.add_user.to.channel.failed.app_error", nil, fmt.Sprintf("failed to add member: user_id: %s, channel_id:%s", user.Id, channel.Id), http.StatusInternalServerError)
}
a.WaitForChannelMembership(channel.Id, user.Id)
if nErr := a.Srv().Store.ChannelMemberHistory().LogJoinEvent(user.Id, channel.Id, model.GetMillis()); nErr != nil {
return nil, model.NewAppError("AddUserToChannel", "app.channel_member_history.log_join_event.internal_error", nil, nErr.Error(), http.StatusInternalServerError)
@@ -1417,7 +1381,7 @@ func (a *App) AddUserToChannel(user *model.User, channel *model.Channel) (*model
}
func (a *App) AddChannelMember(userID string, channel *model.Channel, userRequestorId string, postRootId string) (*model.ChannelMember, *model.AppError) {
if member, err := a.Srv().Store.Channel().GetMember(channel.Id, userID); err != nil {
if member, err := a.Srv().Store.Channel().GetMember(context.Background(), channel.Id, userID); err != nil {
var nfErr *store.ErrNotFound
if !errors.As(err, &nfErr) {
return nil, model.NewAppError("AddChannelMember", "app.channel.get_member.app_error", nil, err.Error(), http.StatusInternalServerError)
@@ -1782,8 +1746,8 @@ func (a *App) GetPrivateChannelsForTeam(teamID string, offset int, limit int) (*
return list, nil
}
func (a *App) GetChannelMember(channelID string, userID string) (*model.ChannelMember, *model.AppError) {
channelMember, err := a.Srv().Store.Channel().GetMember(channelID, userID)
func (a *App) GetChannelMember(ctx context.Context, channelID string, userID string) (*model.ChannelMember, *model.AppError) {
channelMember, err := a.Srv().Store.Channel().GetMember(ctx, channelID, userID)
if err != nil {
var nfErr *store.ErrNotFound
switch {
@@ -1921,7 +1885,7 @@ func (a *App) JoinChannel(channel *model.Channel, userID string) *model.AppError
close(userChan)
}()
go func() {
member, err := a.Srv().Store.Channel().GetMember(channel.Id, userID)
member, err := a.Srv().Store.Channel().GetMember(context.Background(), channel.Id, userID)
memberChan <- store.StoreResult{Data: member, NErr: err}
close(memberChan)
}()
@@ -2215,7 +2179,7 @@ func (a *App) removeUserFromChannel(userIDToRemove string, removerUserId string,
}
}
cm, err := a.GetChannelMember(channel.Id, userIDToRemove)
cm, err := a.GetChannelMember(context.Background(), channel.Id, userIDToRemove)
if err != nil {
return err
}
@@ -2558,7 +2522,7 @@ func (a *App) MarkChannelsAsViewed(channelIDs []string, userID string, currentSe
continue
}
member, err := a.Srv().Store.Channel().GetMember(channelID, userID)
member, err := a.Srv().Store.Channel().GetMember(context.Background(), channelID, userID)
if err != nil {
mlog.Warn("Failed to get membership", mlog.Err(err))
continue
@@ -2846,7 +2810,7 @@ func (a *App) GetPinnedPosts(channelID string) (*model.PostList, *model.AppError
}
func (a *App) ToggleMuteChannel(channelID, userID string) (*model.ChannelMember, *model.AppError) {
member, nErr := a.Srv().Store.Channel().GetMember(channelID, userID)
member, nErr := a.Srv().Store.Channel().GetMember(context.Background(), channelID, userID)
if nErr != nil {
var appErr *model.AppError
var nfErr *store.ErrNotFound
@@ -3029,8 +2993,8 @@ func (a *App) ClearChannelMembersCache(channelID string) {
}
}
func (a *App) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
channelMemberCounts, err := a.Srv().Store.Channel().GetMemberCountsByGroup(channelID, includeTimezones)
func (a *App) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
channelMemberCounts, err := a.Srv().Store.Channel().GetMemberCountsByGroup(ctx, channelID, includeTimezones)
if err != nil {
return nil, model.NewAppError("GetMemberCountsByGroup", "app.channel.get_member_count.app_error", nil, err.Error(), http.StatusInternalServerError)
}

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@@ -168,10 +169,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
assert.True(t, updated[0].Muted)
// Confirm that the channels are now muted
member1, err := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member1.IsChannelMuted())
member2, err := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member2.IsChannelMuted())
@@ -189,10 +190,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
assert.False(t, updated[0].Muted)
// Confirm that the channels are now unmuted
member1, err = th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err = th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member1.IsChannelMuted())
member2, err = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member2.IsChannelMuted())
})
@@ -250,10 +251,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
require.Nil(t, err)
// Confirm that the channels are now muted
member1, err := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member1.IsChannelMuted())
member2, err := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member2.IsChannelMuted())
@@ -279,10 +280,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
require.Nil(t, err)
// Confirm that the channels are now unmuted
member1, err = th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err = th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member1.IsChannelMuted())
member2, err = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member2.IsChannelMuted())
})
@@ -340,10 +341,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
require.Nil(t, err)
// Confirm that the channels are still unmuted
member1, err := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member1.IsChannelMuted())
member2, err := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member2.IsChannelMuted())
@@ -375,10 +376,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
require.Nil(t, err)
// Confirm that the channels are still muted
member1, err = th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err = th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member1.IsChannelMuted())
member2, err = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member2.IsChannelMuted())
})
@@ -436,10 +437,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
require.Nil(t, err)
// Confirm that the channels are still unmuted
member1, err := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member1.IsChannelMuted())
member2, err := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.False(t, member2.IsChannelMuted())
@@ -471,10 +472,10 @@ func TestUpdateSidebarCategories(t *testing.T) {
require.Nil(t, err)
// Confirm that the channels are still muted
member1, err = th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err = th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member1.IsChannelMuted())
member2, err = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
assert.True(t, member2.IsChannelMuted())
})

View File

@@ -337,7 +337,7 @@ func TestJoinDefaultChannelsExperimentalDefaultChannels(t *testing.T) {
channel, err := th.App.GetChannelByName(channelName, th.BasicTeam.Id, false)
require.Nil(t, err, "Expected nil, didn't receive nil")
member, err := th.App.GetChannelMember(channel.Id, user.Id)
member, err := th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.NotNil(t, member, "Expected member object, got nil")
require.Nil(t, err, "Expected nil object, didn't receive nil")
@@ -526,14 +526,14 @@ func TestLeaveDefaultChannel(t *testing.T) {
err = th.App.LeaveChannel(townSquare.Id, th.BasicUser.Id)
assert.NotNil(t, err, "It should fail to remove a regular user from the default channel")
assert.Equal(t, err.Id, "api.channel.remove.default.app_error")
_, err = th.App.GetChannelMember(townSquare.Id, th.BasicUser.Id)
_, err = th.App.GetChannelMember(context.Background(), townSquare.Id, th.BasicUser.Id)
assert.Nil(t, err)
})
t.Run("Guest leaves the default channel", func(t *testing.T) {
err = th.App.LeaveChannel(townSquare.Id, guest.Id)
assert.Nil(t, err, "It should allow to remove a guest user from the default channel")
_, err = th.App.GetChannelMember(townSquare.Id, guest.Id)
_, err = th.App.GetChannelMember(context.Background(), townSquare.Id, guest.Id)
assert.NotNil(t, err)
})
}
@@ -560,7 +560,7 @@ func TestLeaveLastChannel(t *testing.T) {
t.Run("Guest leaves last channel", func(t *testing.T) {
err = th.App.LeaveChannel(th.BasicChannel.Id, guest.Id)
assert.Nil(t, err, "It should allow to remove a guest user from the default channel")
_, err = th.App.GetChannelMember(th.BasicChannel.Id, guest.Id)
_, err = th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, guest.Id)
assert.NotNil(t, err)
_, err = th.App.GetTeamMember(th.BasicTeam.Id, guest.Id)
assert.Nil(t, err, "It should remove the team membership")
@@ -636,11 +636,11 @@ func TestSetChannelsMuted(t *testing.T) {
th.AddUserToChannel(th.BasicUser, channel2)
// Ensure that both channels start unmuted
member1, err := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
require.False(t, member1.IsChannelMuted())
member2, err := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
require.False(t, member2.IsChannelMuted())
@@ -651,11 +651,11 @@ func TestSetChannelsMuted(t *testing.T) {
assert.True(t, updated[1].IsChannelMuted())
// Verify that the channels are muted in the database
member1, err = th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err = th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
require.True(t, member1.IsChannelMuted())
member2, err = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
require.True(t, member2.IsChannelMuted())
@@ -666,11 +666,11 @@ func TestSetChannelsMuted(t *testing.T) {
assert.False(t, updated[1].IsChannelMuted())
// Verify that the channels are muted in the database
member1, err = th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
member1, err = th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
require.Nil(t, err)
require.False(t, member1.IsChannelMuted())
member2, err = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
member2, err = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
require.Nil(t, err)
require.False(t, member2.IsChannelMuted())
})
@@ -1411,7 +1411,7 @@ func TestAddUserToChannel(t *testing.T) {
require.Nil(t, err)
// verify user was added as a non-admin
cm1, err := th.App.GetChannelMember(th.BasicChannel.Id, ruser1.Id)
cm1, err := th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, ruser1.Id)
require.Nil(t, err)
require.False(t, cm1.SchemeAdmin)
@@ -1435,7 +1435,7 @@ func TestAddUserToChannel(t *testing.T) {
require.Nil(t, err)
// verify user was added as an admin
cm2, err := th.App.GetChannelMember(th.BasicChannel.Id, ruser2.Id)
cm2, err := th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, ruser2.Id)
require.Nil(t, err)
require.True(t, cm2.SchemeAdmin)
@@ -1945,7 +1945,7 @@ func TestMarkChannelsAsViewedPanic(t *testing.T) {
mockUserStore.On("Get", context.Background(), "userID").Return(nil, model.NewAppError("SqlUserStore.Get", "app.user.get.app_error", nil, "user_id=userID", http.StatusInternalServerError))
mockChannelStore := mocks.ChannelStore{}
mockChannelStore.On("Get", "channelID", true).Return(&model.Channel{}, nil)
mockChannelStore.On("GetMember", "channelID", "userID").Return(&model.ChannelMember{
mockChannelStore.On("GetMember", context.Background(), "channelID", "userID").Return(&model.ChannelMember{
NotifyProps: model.StringMap{
model.PUSH_NOTIFY_PROP: model.CHANNEL_NOTIFY_DEFAULT,
}}, nil)
@@ -1996,9 +1996,9 @@ func TestGetMemberCountsByGroup(t *testing.T) {
ChannelMemberTimezonesCount: int64(i),
})
}
mockChannelStore.On("GetMemberCountsByGroup", "channelID", true).Return(cmc, nil)
mockChannelStore.On("GetMemberCountsByGroup", context.Background(), "channelID", true).Return(cmc, nil)
mockStore.On("Channel").Return(&mockChannelStore)
resp, err := th.App.GetMemberCountsByGroup("channelID", true)
resp, err := th.App.GetMemberCountsByGroup(context.Background(), "channelID", true)
require.Nil(t, err)
require.ElementsMatch(t, cmc, resp)
}

View File

@@ -571,7 +571,7 @@ func (a *App) HandleCommandResponsePost(command *model.Command, args *model.Comm
post.SetProps(response.Props)
if response.ChannelId != "" {
_, err := a.GetChannelMember(response.ChannelId, args.UserId)
_, err := a.GetChannelMember(context.Background(), response.ChannelId, args.UserId)
if err != nil {
err = model.NewAppError("HandleCommandResponsePost", "api.command.command_post.forbidden.app_error", nil, err.Error(), http.StatusForbidden)
return nil, err

15
app/context.go Normal file
View File

@@ -0,0 +1,15 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package app
import (
"context"
"github.com/mattermost/mattermost-server/v5/store/sqlstore"
)
// WithMaster adds the context value that master DB should be selected for this request.
func WithMaster(ctx context.Context) context.Context {
return sqlstore.WithMaster(ctx)
}

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"sync/atomic"
"testing"
"time"
@@ -89,7 +90,7 @@ func TestCheckPendingNotifications(t *testing.T) {
},
}
channelMember, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMember, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
channelMember.LastViewedAt = 9999999
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -110,7 +111,7 @@ func TestCheckPendingNotifications(t *testing.T) {
require.Len(t, job.pendingNotifications[th.BasicUser.Id], 1, "shouldn't have sent queued post")
// test that notifications are cleared if the user has acted
channelMember, err = th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMember, err = th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
channelMember.LastViewedAt = 10001000
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -208,7 +209,7 @@ func TestCheckPendingNotificationsDefaultInterval(t *testing.T) {
job := NewEmailBatchingJob(th.Server.EmailService, 128)
// bypasses recent user activity check
channelMember, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMember, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
channelMember.LastViewedAt = 9999000
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)
@@ -246,7 +247,7 @@ func TestCheckPendingNotificationsCantParseInterval(t *testing.T) {
job := NewEmailBatchingJob(th.Server.EmailService, 128)
// bypasses recent user activity check
channelMember, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMember, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
channelMember.LastViewedAt = 9999000
_, err = th.App.Srv().Store.Channel().UpdateMember(channelMember)

View File

@@ -1084,7 +1084,7 @@ func TestImportImportUser(t *testing.T) {
require.Equal(t, channelMemberCount+1, cmc, "Number of channel members not as expected")
// Check channel member properties.
channelMember, appErr := th.App.GetChannelMember(channel.Id, user.Id)
channelMember, appErr := th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.Nil(t, appErr, "Failed to get channel member from database.")
assert.Equal(t, "channel_user", channelMember.Roles)
assert.Equal(t, "default", channelMember.NotifyProps[model.DESKTOP_NOTIFY_PROP])
@@ -1119,7 +1119,7 @@ func TestImportImportUser(t *testing.T) {
require.Nil(t, appErr, "Failed to get team member from database.")
require.Equal(t, "team_user team_admin", teamMember.Roles)
channelMember, appErr = th.App.GetChannelMember(channel.Id, user.Id)
channelMember, appErr = th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.Nil(t, appErr, "Failed to get channel member Desktop from database.")
assert.Equal(t, "channel_user channel_admin", channelMember.Roles)
assert.Equal(t, model.USER_NOTIFY_MENTION, channelMember.NotifyProps[model.DESKTOP_NOTIFY_PROP])
@@ -1435,7 +1435,7 @@ func TestImportImportUser(t *testing.T) {
assert.False(t, teamMember.SchemeGuest)
assert.Equal(t, "", teamMember.ExplicitRoles)
channelMember, appErr = th.App.GetChannelMember(channel.Id, user.Id)
channelMember, appErr = th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.Nil(t, appErr, "Failed to get the channel member")
assert.True(t, channelMember.SchemeAdmin)
@@ -1477,7 +1477,7 @@ func TestImportImportUser(t *testing.T) {
assert.False(t, teamMember.SchemeGuest)
assert.Equal(t, "", teamMember.ExplicitRoles)
channelMember, appErr = th.App.GetChannelMember(channel.Id, user.Id)
channelMember, appErr = th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.Nil(t, appErr, "Failed to get the channel member")
assert.False(t, teamMember.SchemeAdmin)
@@ -1519,7 +1519,7 @@ func TestImportImportUser(t *testing.T) {
assert.True(t, teamMember.SchemeGuest)
assert.Equal(t, "", teamMember.ExplicitRoles)
channelMember, appErr = th.App.GetChannelMember(channel.Id, user.Id)
channelMember, appErr = th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.Nil(t, appErr, "Failed to get the channel member")
assert.False(t, teamMember.SchemeAdmin)

View File

@@ -4697,7 +4697,7 @@ func (a *OpenTracingAppLayer) GetChannelGuestCount(channelID string) (int64, *mo
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) GetChannelMember(channelID string, userID string) (*model.ChannelMember, *model.AppError) {
func (a *OpenTracingAppLayer) GetChannelMember(ctx context.Context, channelID string, userID string) (*model.ChannelMember, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetChannelMember")
@@ -4709,7 +4709,7 @@ func (a *OpenTracingAppLayer) GetChannelMember(channelID string, userID string)
}()
defer span.Finish()
resultVar0, resultVar1 := a.app.GetChannelMember(channelID, userID)
resultVar0, resultVar1 := a.app.GetChannelMember(ctx, channelID, userID)
if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar1))
@@ -6366,7 +6366,7 @@ func (a *OpenTracingAppLayer) GetMarketplacePlugins(filter *model.MarketplacePlu
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetMemberCountsByGroup")
@@ -6378,7 +6378,7 @@ func (a *OpenTracingAppLayer) GetMemberCountsByGroup(channelID string, includeTi
}()
defer span.Finish()
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(channelID, includeTimezones)
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(ctx, channelID, includeTimezones)
if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar1))
@@ -16240,21 +16240,6 @@ func (a *OpenTracingAppLayer) ViewChannel(view *model.ChannelView, userID string
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) WaitForChannelMembership(channelID string, userID string) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.WaitForChannelMembership")
a.ctx = newCtx
a.app.Srv().Store.SetContext(newCtx)
defer func() {
a.app.Srv().Store.SetContext(origCtx)
a.ctx = origCtx
}()
defer span.Finish()
a.app.WaitForChannelMembership(channelID, userID)
}
func (a *OpenTracingAppLayer) WriteFile(fr io.Reader, path string) (int64, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.WriteFile")

View File

@@ -5,6 +5,7 @@ package app
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -501,7 +502,7 @@ func (api *PluginAPI) AddUserToChannel(channelID, userID, asUserId string) (*mod
}
func (api *PluginAPI) GetChannelMember(channelID, userID string) (*model.ChannelMember, *model.AppError) {
return api.app.GetChannelMember(channelID, userID)
return api.app.GetChannelMember(context.Background(), channelID, userID)
}
func (api *PluginAPI) GetChannelMembers(channelID string, page, perPage int) (*model.ChannelMembers, *model.AppError) {

View File

@@ -974,7 +974,7 @@ func (a *App) AddCursorIdsForPostList(originalList *model.PostList, afterPost, b
func (a *App) GetPostsForChannelAroundLastUnread(channelID, userID string, limitBefore, limitAfter int, skipFetchThreads bool, collapsedThreads, collapsedThreadsExtended bool) (*model.PostList, *model.AppError) {
var member *model.ChannelMember
var err *model.AppError
if member, err = a.GetChannelMember(channelID, userID); err != nil {
if member, err = a.GetChannelMember(context.Background(), channelID, userID); err != nil {
return nil, err
} else if member.LastViewedAt == 0 {
return model.NewPostList(), nil
@@ -1419,7 +1419,7 @@ func (a *App) countMentionsFromPost(user *model.User, post *model.Post) (int, *m
return count, nil
}
channelMember, err := a.GetChannelMember(channel.Id, user.Id)
channelMember, err := a.GetChannelMember(context.Background(), channel.Id, user.Id)
if err != nil {
return 0, err
}

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"fmt"
"net/http"
"os"
@@ -855,14 +856,14 @@ func TestCreatePostAsUser(t *testing.T) {
UserId: th.BasicUser.Id,
}
channelMemberBefore, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMemberBefore, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
time.Sleep(1 * time.Millisecond)
_, appErr := th.App.CreatePostAsUser(post, "", true)
require.Nil(t, appErr)
channelMemberAfter, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMemberAfter, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
require.Greater(t, channelMemberAfter.LastViewedAt, channelMemberBefore.LastViewedAt)
@@ -879,14 +880,14 @@ func TestCreatePostAsUser(t *testing.T) {
}
post.AddProp("from_webhook", "true")
channelMemberBefore, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMemberBefore, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
time.Sleep(1 * time.Millisecond)
_, appErr := th.App.CreatePostAsUser(post, "", true)
require.Nil(t, appErr)
channelMemberAfter, err := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMemberAfter, err := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, err)
require.Equal(t, channelMemberAfter.LastViewedAt, channelMemberBefore.LastViewedAt)
@@ -910,14 +911,14 @@ func TestCreatePostAsUser(t *testing.T) {
UserId: bot.UserId,
}
channelMemberBefore, nErr := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMemberBefore, nErr := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, nErr)
time.Sleep(1 * time.Millisecond)
_, appErr = th.App.CreatePostAsUser(post, "", true)
require.Nil(t, appErr)
channelMemberAfter, nErr := th.App.Srv().Store.Channel().GetMember(th.BasicChannel.Id, th.BasicUser.Id)
channelMemberAfter, nErr := th.App.Srv().Store.Channel().GetMember(context.Background(), th.BasicChannel.Id, th.BasicUser.Id)
require.NoError(t, nErr)
require.Equal(t, channelMemberAfter.LastViewedAt, channelMemberBefore.LastViewedAt)

View File

@@ -4,6 +4,8 @@
package slashcommands
import (
"context"
"github.com/mattermost/mattermost-server/v5/app"
"github.com/mattermost/mattermost-server/v5/model"
"github.com/mattermost/mattermost-server/v5/shared/i18n"
@@ -63,7 +65,7 @@ func (*HeaderProvider) DoCommand(a *app.App, args *model.CommandArgs, message st
case model.CHANNEL_GROUP, model.CHANNEL_DIRECT:
// Modifying the header is not linked to any specific permission for group/dm channels, so just check for membership.
var channelMember *model.ChannelMember
channelMember, err = a.GetChannelMember(args.ChannelId, args.UserId)
channelMember, err = a.GetChannelMember(context.Background(), args.ChannelId, args.UserId)
if err != nil || channelMember == nil {
return &model.CommandResponse{
Text: args.T("api.command_channel_header.permission.app_error"),

View File

@@ -4,6 +4,7 @@
package slashcommands
import (
"context"
"strings"
"github.com/mattermost/mattermost-server/v5/app"
@@ -103,7 +104,7 @@ func (*InviteProvider) DoCommand(a *app.App, args *model.CommandArgs, message st
}
case model.CHANNEL_PRIVATE:
if !a.HasPermissionToChannel(args.UserId, channelToJoin.Id, model.PERMISSION_MANAGE_PRIVATE_CHANNEL_MEMBERS) {
if _, err = a.GetChannelMember(channelToJoin.Id, args.UserId); err == nil {
if _, err = a.GetChannelMember(context.Background(), channelToJoin.Id, args.UserId); err == nil {
// User doing the inviting is a member of the channel.
return &model.CommandResponse{
Text: args.T("api.command_invite.permission.app_error", map[string]interface{}{
@@ -129,7 +130,7 @@ func (*InviteProvider) DoCommand(a *app.App, args *model.CommandArgs, message st
}
// Check if user is already in the channel
_, err = a.GetChannelMember(channelToJoin.Id, userProfile.Id)
_, err = a.GetChannelMember(context.Background(), channelToJoin.Id, userProfile.Id)
if err == nil {
return &model.CommandResponse{
Text: args.T("api.command_invite.user_already_in_channel.app_error", map[string]interface{}{

View File

@@ -4,6 +4,7 @@
package slashcommands
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@@ -80,7 +81,7 @@ func TestLeaveProviderDoCommand(t *testing.T) {
assert.Equal(t, args.SiteURL+"/"+th.BasicTeam.Name+"/channels/"+model.DEFAULT_CHANNEL, actual.GotoLocation)
assert.Equal(t, "", actual.ResponseType)
_, err = th.App.GetChannelMember(publicChannel.Id, th.BasicUser.Id)
_, err = th.App.GetChannelMember(context.Background(), publicChannel.Id, th.BasicUser.Id)
assert.NotNil(t, err)
assert.NotNil(t, err.Id, "app.channel.get_member.missing.app_error")
})
@@ -122,7 +123,7 @@ func TestLeaveProviderDoCommand(t *testing.T) {
assert.Equal(t, args.SiteURL+"/"+th.BasicTeam.Name+"/channels/"+publicChannel.Name, actual.GotoLocation)
assert.Equal(t, "", actual.ResponseType)
_, err = th.App.GetChannelMember(defaultChannel.Id, guest.Id)
_, err = th.App.GetChannelMember(context.Background(), defaultChannel.Id, guest.Id)
assert.NotNil(t, err)
assert.NotNil(t, err.Id, "app.channel.get_member.missing.app_error")
})
@@ -140,7 +141,7 @@ func TestLeaveProviderDoCommand(t *testing.T) {
assert.Equal(t, args.SiteURL+"/", actual.GotoLocation)
assert.Equal(t, "", actual.ResponseType)
_, err = th.App.GetChannelMember(publicChannel.Id, guest.Id)
_, err = th.App.GetChannelMember(context.Background(), publicChannel.Id, guest.Id)
assert.NotNil(t, err)
assert.NotNil(t, err.Id, "app.channel.get_member.missing.app_error")
})

View File

@@ -4,6 +4,7 @@
package slashcommands
import (
"context"
"testing"
"time"
@@ -22,7 +23,7 @@ func TestMuteCommandNoChannel(t *testing.T) {
}
channel1 := th.BasicChannel
channel1M, channel1MError := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
channel1M, channel1MError := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
assert.Nil(t, channel1MError, "User is not a member of channel 1")
assert.NotEqual(
@@ -45,7 +46,7 @@ func TestMuteCommandNoArgs(t *testing.T) {
defer th.tearDown()
channel1 := th.BasicChannel
channel1M, _ := th.App.GetChannelMember(channel1.Id, th.BasicUser.Id)
channel1M, _ := th.App.GetChannelMember(context.Background(), channel1.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_ALL, channel1M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
@@ -87,7 +88,7 @@ func TestMuteCommandSpecificChannel(t *testing.T) {
CreatorId: th.BasicUser.Id,
}, true)
channel2M, _ := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
channel2M, _ := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_ALL, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
@@ -100,7 +101,7 @@ func TestMuteCommandSpecificChannel(t *testing.T) {
UserId: th.BasicUser.Id,
}, channel2.Name)
assert.Equal(t, "api.command_mute.success_mute", resp.Text)
channel2M, _ = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
channel2M, _ = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_MENTION, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
// Now unmute the channel
@@ -111,7 +112,7 @@ func TestMuteCommandSpecificChannel(t *testing.T) {
}, "~"+channel2.Name)
assert.Equal(t, "api.command_mute.success_unmute", resp.Text)
channel2M, _ = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
channel2M, _ = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_ALL, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
}
@@ -173,7 +174,7 @@ func TestMuteCommandDMChannel(t *testing.T) {
}
channel2, _ := th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
channel2M, _ := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
channel2M, _ := th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_ALL, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
@@ -187,7 +188,7 @@ func TestMuteCommandDMChannel(t *testing.T) {
}, "")
assert.Equal(t, "api.command_mute.success_mute_direct_msg", resp.Text)
time.Sleep(time.Millisecond)
channel2M, _ = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
channel2M, _ = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_MENTION, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
// Now unmute the channel
@@ -199,6 +200,6 @@ func TestMuteCommandDMChannel(t *testing.T) {
assert.Equal(t, "api.command_mute.success_unmute_direct_msg", resp.Text)
time.Sleep(time.Millisecond)
channel2M, _ = th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
channel2M, _ = th.App.GetChannelMember(context.Background(), channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_ALL, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])
}

View File

@@ -4,6 +4,7 @@
package slashcommands
import (
"context"
"strings"
"github.com/mattermost/mattermost-server/v5/app"
@@ -122,7 +123,7 @@ func doCommand(a *app.App, args *model.CommandArgs, message string) *model.Comma
}
}
_, err = a.GetChannelMember(args.ChannelId, userProfile.Id)
_, err = a.GetChannelMember(context.Background(), args.ChannelId, userProfile.Id)
if err != nil {
nameFormat := *a.Config().TeamSettings.TeammateNameDisplay
return &model.CommandResponse{

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@@ -113,7 +114,7 @@ func TestCreateDefaultMemberships(t *testing.T) {
if err != nil {
t.Errorf("error retrieving team member: %s", err.Error())
}
_, err = th.App.GetChannelMember(practiceChannel.Id, singer1.Id)
_, err = th.App.GetChannelMember(context.Background(), practiceChannel.Id, singer1.Id)
if err != nil {
t.Errorf("error retrieving channel member: %s", err.Error())
}
@@ -142,7 +143,7 @@ func TestCreateDefaultMemberships(t *testing.T) {
t.Errorf("wrong error: %s", err.Id)
}
_, err = th.App.GetChannelMember(experimentsChannel.Id, scientist1.Id)
_, err = th.App.GetChannelMember(context.Background(), experimentsChannel.Id, scientist1.Id)
if err.Id != "app.channel.get_member.missing.app_error" {
t.Errorf("wrong error: %s", err.Id)
}
@@ -184,7 +185,7 @@ func TestCreateDefaultMemberships(t *testing.T) {
t.Errorf("error retrieving team member: %s", err.Error())
}
_, err = th.App.GetChannelMember(experimentsChannel.Id, scientist1.Id)
_, err = th.App.GetChannelMember(context.Background(), experimentsChannel.Id, scientist1.Id)
if err.Id != "app.channel.get_member.missing.app_error" {
t.Errorf("wrong error: %s", err.Id)
}
@@ -255,7 +256,7 @@ func TestCreateDefaultMemberships(t *testing.T) {
t.Error("expected team member to remain deleted")
}
_, err = th.App.GetChannelMember(practiceChannel.Id, singer1.Id)
_, err = th.App.GetChannelMember(context.Background(), practiceChannel.Id, singer1.Id)
if err == nil {
t.Error("Expected channel member to remain deleted")
}
@@ -308,7 +309,7 @@ func TestCreateDefaultMemberships(t *testing.T) {
t.Errorf("failed to populate syncables: %s", pErr.Error())
}
_, err = th.App.GetChannelMember(experimentsChannel.Id, scientist1.Id)
_, err = th.App.GetChannelMember(context.Background(), experimentsChannel.Id, scientist1.Id)
if err == nil {
t.Error("Expected channel member to remain deleted")
}
@@ -325,7 +326,7 @@ func TestCreateDefaultMemberships(t *testing.T) {
}
// Channel member is re-added.
_, err = th.App.GetChannelMember(experimentsChannel.Id, scientist1.Id)
_, err = th.App.GetChannelMember(context.Background(), experimentsChannel.Id, scientist1.Id)
if err != nil {
t.Errorf("expected channel member: %s", err.Error())
}
@@ -501,7 +502,7 @@ func TestSyncSyncableRoles(t *testing.T) {
require.Nil(t, err)
require.True(t, tm.SchemeAdmin)
cm, err := th.App.GetChannelMember(channel.Id, user.Id)
cm, err := th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
require.Nil(t, err)
require.True(t, cm.SchemeAdmin)
}

View File

@@ -5,6 +5,7 @@ package app
import (
"bytes"
"context"
"encoding/json"
"errors"
"image"
@@ -1145,7 +1146,7 @@ func TestPromoteGuestToUser(t *testing.T) {
assert.Nil(t, err)
assert.False(t, teamMember.SchemeGuest)
assert.True(t, teamMember.SchemeUser)
channelMember, err = th.App.GetChannelMember(th.BasicChannel.Id, guest.Id)
channelMember, err = th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, guest.Id)
assert.Nil(t, err)
assert.False(t, teamMember.SchemeGuest)
assert.True(t, teamMember.SchemeUser)
@@ -1177,7 +1178,7 @@ func TestPromoteGuestToUser(t *testing.T) {
assert.Nil(t, err)
assert.False(t, teamMember.SchemeGuest)
assert.True(t, teamMember.SchemeUser)
channelMember, err = th.App.GetChannelMember(th.BasicChannel.Id, guest.Id)
channelMember, err = th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, guest.Id)
assert.Nil(t, err)
assert.False(t, teamMember.SchemeGuest)
assert.True(t, teamMember.SchemeUser)
@@ -1308,7 +1309,7 @@ func TestDemoteUserToGuest(t *testing.T) {
assert.Nil(t, err)
assert.False(t, teamMember.SchemeUser)
assert.True(t, teamMember.SchemeGuest)
channelMember, err = th.App.GetChannelMember(th.BasicChannel.Id, user.Id)
channelMember, err = th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, user.Id)
assert.Nil(t, err)
assert.False(t, teamMember.SchemeUser)
assert.True(t, teamMember.SchemeGuest)
@@ -1340,7 +1341,7 @@ func TestDemoteUserToGuest(t *testing.T) {
assert.Nil(t, err)
assert.False(t, teamMember.SchemeUser)
assert.True(t, teamMember.SchemeGuest)
channelMember, err = th.App.GetChannelMember(th.BasicChannel.Id, user.Id)
channelMember, err = th.App.GetChannelMember(context.Background(), th.BasicChannel.Id, user.Id)
assert.Nil(t, err)
assert.False(t, teamMember.SchemeUser)
assert.True(t, teamMember.SchemeGuest)
@@ -1370,7 +1371,7 @@ func TestDemoteUserToGuest(t *testing.T) {
th.AddUserToChannel(user, channel)
th.App.UpdateChannelMemberSchemeRoles(channel.Id, user.Id, false, true, true)
channelMember, err := th.App.GetChannelMember(channel.Id, user.Id)
channelMember, err := th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
assert.Nil(t, err)
assert.True(t, channelMember.SchemeUser)
assert.True(t, channelMember.SchemeAdmin)
@@ -1389,7 +1390,7 @@ func TestDemoteUserToGuest(t *testing.T) {
assert.False(t, teamMember.SchemeAdmin)
assert.True(t, teamMember.SchemeGuest)
channelMember, err = th.App.GetChannelMember(channel.Id, user.Id)
channelMember, err = th.App.GetChannelMember(context.Background(), channel.Id, user.Id)
assert.Nil(t, err)
assert.False(t, channelMember.SchemeUser)
assert.False(t, channelMember.SchemeAdmin)

View File

@@ -1183,7 +1183,7 @@ func (s *OpenTracingLayerChannelStore) GetGuestCount(channelID string, allowFrom
return result, err
}
func (s *OpenTracingLayerChannelStore) GetMember(channelID string, userId string) (*model.ChannelMember, error) {
func (s *OpenTracingLayerChannelStore) GetMember(ctx context.Context, channelID string, userId string) (*model.ChannelMember, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.GetMember")
s.Root.Store.SetContext(newCtx)
@@ -1192,7 +1192,7 @@ func (s *OpenTracingLayerChannelStore) GetMember(channelID string, userId string
}()
defer span.Finish()
result, err := s.ChannelStore.GetMember(channelID, userId)
result, err := s.ChannelStore.GetMember(ctx, channelID, userId)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)
@@ -1232,7 +1232,7 @@ func (s *OpenTracingLayerChannelStore) GetMemberCountFromCache(channelID string)
return result
}
func (s *OpenTracingLayerChannelStore) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
func (s *OpenTracingLayerChannelStore) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.GetMemberCountsByGroup")
s.Root.Store.SetContext(newCtx)
@@ -1241,7 +1241,7 @@ func (s *OpenTracingLayerChannelStore) GetMemberCountsByGroup(channelID string,
}()
defer span.Finish()
result, err := s.ChannelStore.GetMemberCountsByGroup(channelID, includeTimezones)
result, err := s.ChannelStore.GetMemberCountsByGroup(ctx, channelID, includeTimezones)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)

View File

@@ -1284,11 +1284,11 @@ func (s *RetryLayerChannelStore) GetGuestCount(channelID string, allowFromCache
}
func (s *RetryLayerChannelStore) GetMember(channelID string, userId string) (*model.ChannelMember, error) {
func (s *RetryLayerChannelStore) GetMember(ctx context.Context, channelID string, userId string) (*model.ChannelMember, error) {
tries := 0
for {
result, err := s.ChannelStore.GetMember(channelID, userId)
result, err := s.ChannelStore.GetMember(ctx, channelID, userId)
if err == nil {
return result, nil
}
@@ -1330,11 +1330,11 @@ func (s *RetryLayerChannelStore) GetMemberCountFromCache(channelID string) int64
}
func (s *RetryLayerChannelStore) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
func (s *RetryLayerChannelStore) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
tries := 0
for {
result, err := s.ChannelStore.GetMemberCountsByGroup(channelID, includeTimezones)
result, err := s.ChannelStore.GetMemberCountsByGroup(ctx, channelID, includeTimezones)
if err == nil {
return result, nil
}

View File

@@ -4,6 +4,7 @@
package sqlstore
import (
"context"
"database/sql"
"fmt"
"sort"
@@ -1631,10 +1632,10 @@ func (s SqlChannelStore) GetChannelMembersTimezones(channelId string) ([]model.S
return dbMembersTimezone, nil
}
func (s SqlChannelStore) GetMember(channelId string, userId string) (*model.ChannelMember, error) {
func (s SqlChannelStore) GetMember(ctx context.Context, channelId string, userId string) (*model.ChannelMember, error) {
var dbMember channelMemberWithSchemeRoles
if err := s.GetReplica().SelectOne(&dbMember, ChannelMembersWithSchemeSelectQuery+"WHERE ChannelMembers.ChannelId = :ChannelId AND ChannelMembers.UserId = :UserId", map[string]interface{}{"ChannelId": channelId, "UserId": userId}); err != nil {
if err := s.DBFromContext(ctx).SelectOne(&dbMember, ChannelMembersWithSchemeSelectQuery+"WHERE ChannelMembers.ChannelId = :ChannelId AND ChannelMembers.UserId = :UserId", map[string]interface{}{"ChannelId": channelId, "UserId": userId}); err != nil {
if err == sql.ErrNoRows {
return nil, store.NewErrNotFound("ChannelMember", fmt.Sprintf("channelId=%s, userId=%s", channelId, userId))
}
@@ -1866,7 +1867,7 @@ func (s SqlChannelStore) GetMemberCount(channelId string, allowFromCache bool) (
// GetMemberCountsByGroup returns a slice of ChannelMemberCountByGroup for a given channel
// which contains the number of channel members for each group and optionally the number of unique timezones present for each group in the channel
func (s SqlChannelStore) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
func (s SqlChannelStore) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
selectStr := "GroupMembers.GroupId, COUNT(ChannelMembers.UserId) AS ChannelMemberCount"
if includeTimezones {
@@ -1927,7 +1928,7 @@ func (s SqlChannelStore) GetMemberCountsByGroup(channelID string, includeTimezon
return nil, errors.Wrap(err, "channel_tosql")
}
var data []*model.ChannelMemberCountByGroup
if _, err = s.GetReplica().Select(&data, queryString, args...); err != nil {
if _, err = s.DBFromContext(ctx).Select(&data, queryString, args...); err != nil {
return nil, errors.Wrapf(err, "failed to count ChannelMembers with channelId=%s", channelID)
}

View File

@@ -171,7 +171,7 @@ type ChannelStore interface {
UpdateMember(member *model.ChannelMember) (*model.ChannelMember, error)
UpdateMultipleMembers(members []*model.ChannelMember) ([]*model.ChannelMember, error)
GetMembers(channelID string, offset, limit int) (*model.ChannelMembers, error)
GetMember(channelID string, userId string) (*model.ChannelMember, error)
GetMember(ctx context.Context, channelID string, userId string) (*model.ChannelMember, error)
GetChannelMembersTimezones(channelID string) ([]model.StringMap, error)
GetAllChannelMembersForUser(userId string, allowFromCache bool, includeDeleted bool) (map[string]string, error)
InvalidateAllChannelMembersForUser(userId string)
@@ -182,7 +182,7 @@ type ChannelStore interface {
InvalidateMemberCount(channelID string)
GetMemberCountFromCache(channelID string) int64
GetMemberCount(channelID string, allowFromCache bool) (int64, error)
GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error)
GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error)
InvalidatePinnedPostCount(channelID string)
GetPinnedPostCount(channelID string, allowFromCache bool) (int64, error)
InvalidateGuestCount(channelID string)

View File

@@ -4,6 +4,7 @@
package storetest
import (
"context"
"errors"
"sort"
"strconv"
@@ -887,7 +888,7 @@ func testChannelMemberStore(t *testing.T, ss store.Store) {
c1t3, _ := ss.Channel().Get(c1.Id, false)
assert.EqualValues(t, 0, c1t3.ExtraUpdateAt, "ExtraUpdateAt should be 0")
member, _ := ss.Channel().GetMember(o1.ChannelId, o1.UserId)
member, _ := ss.Channel().GetMember(context.Background(), o1.ChannelId, o1.UserId)
require.Equal(t, o1.ChannelId, member.ChannelId, "should have go member")
_, nErr = ss.Channel().SaveMember(&o1)
@@ -4166,13 +4167,13 @@ func testChannelStoreUpdateLastViewedAt(t *testing.T, ss store.Store) {
require.NoError(t, err, "failed to update ", err)
require.Equal(t, o2.LastPostAt, times[o2.Id], "last viewed at time incorrect")
rm1, err := ss.Channel().GetMember(m1.ChannelId, m1.UserId)
rm1, err := ss.Channel().GetMember(context.Background(), m1.ChannelId, m1.UserId)
assert.NoError(t, err)
assert.Equal(t, o1.LastPostAt, rm1.LastViewedAt)
assert.Equal(t, o1.LastPostAt, rm1.LastUpdateAt)
assert.Equal(t, o1.TotalMsgCount, rm1.MsgCount)
rm2, err := ss.Channel().GetMember(m2.ChannelId, m2.UserId)
rm2, err := ss.Channel().GetMember(context.Background(), m2.ChannelId, m2.UserId)
assert.NoError(t, err)
assert.Equal(t, o2.LastPostAt, rm2.LastViewedAt)
assert.Equal(t, o2.LastPostAt, rm2.LastUpdateAt)
@@ -4278,18 +4279,18 @@ func testGetMember(t *testing.T, ss store.Store) {
_, err = ss.Channel().SaveMember(m2)
require.NoError(t, err)
_, err = ss.Channel().GetMember(model.NewId(), userId)
_, err = ss.Channel().GetMember(context.Background(), model.NewId(), userId)
require.Error(t, err, "should've failed to get member for non-existent channel")
_, err = ss.Channel().GetMember(c1.Id, model.NewId())
_, err = ss.Channel().GetMember(context.Background(), c1.Id, model.NewId())
require.Error(t, err, "should've failed to get member for non-existent user")
member, err := ss.Channel().GetMember(c1.Id, userId)
member, err := ss.Channel().GetMember(context.Background(), c1.Id, userId)
require.NoError(t, err, "shouldn't have errored when getting member", err)
require.Equal(t, c1.Id, member.ChannelId, "should've gotten member of channel 1")
require.Equal(t, userId, member.UserId, "should've have gotten member for user")
member, err = ss.Channel().GetMember(c2.Id, userId)
member, err = ss.Channel().GetMember(context.Background(), c2.Id, userId)
require.NoError(t, err, "should'nt have errored when getting member", err)
require.Equal(t, c2.Id, member.ChannelId, "should've gotten member of channel 2")
require.Equal(t, userId, member.UserId, "should've gotten member for user")
@@ -4486,7 +4487,7 @@ func testGetMemberCountsByGroup(t *testing.T, ss store.Store) {
require.NoError(t, nErr)
t.Run("empty slice for channel with no groups", func(t *testing.T) {
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(c1.Id, false)
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(context.Background(), c1.Id, false)
expectedMemberCounts := []*model.ChannelMemberCountByGroup{}
require.NoError(t, nErr)
require.Equal(t, expectedMemberCounts, memberCounts)
@@ -4496,7 +4497,7 @@ func testGetMemberCountsByGroup(t *testing.T, ss store.Store) {
require.NoError(t, err)
t.Run("returns memberCountsByGroup without timezones", func(t *testing.T) {
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(c1.Id, false)
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(context.Background(), c1.Id, false)
expectedMemberCounts := []*model.ChannelMemberCountByGroup{
{
GroupId: g1.Id,
@@ -4509,7 +4510,7 @@ func testGetMemberCountsByGroup(t *testing.T, ss store.Store) {
})
t.Run("returns memberCountsByGroup with timezones when no timezones set", func(t *testing.T) {
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(c1.Id, true)
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(context.Background(), c1.Id, true)
expectedMemberCounts := []*model.ChannelMemberCountByGroup{
{
GroupId: g1.Id,
@@ -4612,7 +4613,7 @@ func testGetMemberCountsByGroup(t *testing.T, ss store.Store) {
}
t.Run("returns memberCountsByGroup for multiple groups with lots of users without timezones", func(t *testing.T) {
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(c1.Id, false)
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(context.Background(), c1.Id, false)
expectedMemberCounts := []*model.ChannelMemberCountByGroup{
{
GroupId: g1.Id,
@@ -4635,7 +4636,7 @@ func testGetMemberCountsByGroup(t *testing.T, ss store.Store) {
})
t.Run("returns memberCountsByGroup for multiple groups with lots of users with timezones", func(t *testing.T) {
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(c1.Id, true)
memberCounts, nErr = ss.Channel().GetMemberCountsByGroup(context.Background(), c1.Id, true)
expectedMemberCounts := []*model.ChannelMemberCountByGroup{
{
GroupId: g1.Id,
@@ -6117,21 +6118,21 @@ func testChannelStoreMigrateChannelMembers(t *testing.T, ss store.Store) {
ss.Channel().ClearCaches()
cm1b, err := ss.Channel().GetMember(cm1.ChannelId, cm1.UserId)
cm1b, err := ss.Channel().GetMember(context.Background(), cm1.ChannelId, cm1.UserId)
assert.NoError(t, err)
assert.Equal(t, "", cm1b.ExplicitRoles)
assert.False(t, cm1b.SchemeGuest)
assert.True(t, cm1b.SchemeUser)
assert.True(t, cm1b.SchemeAdmin)
cm2b, err := ss.Channel().GetMember(cm2.ChannelId, cm2.UserId)
cm2b, err := ss.Channel().GetMember(context.Background(), cm2.ChannelId, cm2.UserId)
assert.NoError(t, err)
assert.Equal(t, "", cm2b.ExplicitRoles)
assert.False(t, cm1b.SchemeGuest)
assert.True(t, cm2b.SchemeUser)
assert.False(t, cm2b.SchemeAdmin)
cm3b, err := ss.Channel().GetMember(cm3.ChannelId, cm3.UserId)
cm3b, err := ss.Channel().GetMember(context.Background(), cm3.ChannelId, cm3.UserId)
assert.NoError(t, err)
assert.Equal(t, "something_else", cm3b.ExplicitRoles)
assert.False(t, cm1b.SchemeGuest)
@@ -6227,19 +6228,19 @@ func testChannelStoreClearAllCustomRoleAssignments(t *testing.T, ss store.Store)
require.NoError(t, ss.Channel().ClearAllCustomRoleAssignments())
member, err := ss.Channel().GetMember(m1.ChannelId, m1.UserId)
member, err := ss.Channel().GetMember(context.Background(), m1.ChannelId, m1.UserId)
require.NoError(t, err)
assert.Equal(t, m1.ExplicitRoles, member.Roles)
member, err = ss.Channel().GetMember(m2.ChannelId, m2.UserId)
member, err = ss.Channel().GetMember(context.Background(), m2.ChannelId, m2.UserId)
require.NoError(t, err)
assert.Equal(t, "channel_user channel_admin", member.Roles)
member, err = ss.Channel().GetMember(m3.ChannelId, m3.UserId)
member, err = ss.Channel().GetMember(context.Background(), m3.ChannelId, m3.UserId)
require.NoError(t, err)
assert.Equal(t, m3.ExplicitRoles, member.Roles)
member, err = ss.Channel().GetMember(m4.ChannelId, m4.UserId)
member, err = ss.Channel().GetMember(context.Background(), m4.ChannelId, m4.UserId)
require.NoError(t, err)
assert.Equal(t, "", member.Roles)
}

View File

@@ -5,9 +5,12 @@
package mocks
import (
context "context"
model "github.com/mattermost/mattermost-server/v5/model"
store "github.com/mattermost/mattermost-server/v5/store"
mock "github.com/stretchr/testify/mock"
store "github.com/mattermost/mattermost-server/v5/store"
)
// ChannelStore is an autogenerated mock type for the ChannelStore type
@@ -816,13 +819,13 @@ func (_m *ChannelStore) GetGuestCount(channelID string, allowFromCache bool) (in
return r0, r1
}
// GetMember provides a mock function with given fields: channelID, userId
func (_m *ChannelStore) GetMember(channelID string, userId string) (*model.ChannelMember, error) {
ret := _m.Called(channelID, userId)
// GetMember provides a mock function with given fields: ctx, channelID, userId
func (_m *ChannelStore) GetMember(ctx context.Context, channelID string, userId string) (*model.ChannelMember, error) {
ret := _m.Called(ctx, channelID, userId)
var r0 *model.ChannelMember
if rf, ok := ret.Get(0).(func(string, string) *model.ChannelMember); ok {
r0 = rf(channelID, userId)
if rf, ok := ret.Get(0).(func(context.Context, string, string) *model.ChannelMember); ok {
r0 = rf(ctx, channelID, userId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.ChannelMember)
@@ -830,8 +833,8 @@ func (_m *ChannelStore) GetMember(channelID string, userId string) (*model.Chann
}
var r1 error
if rf, ok := ret.Get(1).(func(string, string) error); ok {
r1 = rf(channelID, userId)
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, channelID, userId)
} else {
r1 = ret.Error(1)
}
@@ -874,13 +877,13 @@ func (_m *ChannelStore) GetMemberCountFromCache(channelID string) int64 {
return r0
}
// GetMemberCountsByGroup provides a mock function with given fields: channelID, includeTimezones
func (_m *ChannelStore) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
ret := _m.Called(channelID, includeTimezones)
// GetMemberCountsByGroup provides a mock function with given fields: ctx, channelID, includeTimezones
func (_m *ChannelStore) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
ret := _m.Called(ctx, channelID, includeTimezones)
var r0 []*model.ChannelMemberCountByGroup
if rf, ok := ret.Get(0).(func(string, bool) []*model.ChannelMemberCountByGroup); ok {
r0 = rf(channelID, includeTimezones)
if rf, ok := ret.Get(0).(func(context.Context, string, bool) []*model.ChannelMemberCountByGroup); ok {
r0 = rf(ctx, channelID, includeTimezones)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*model.ChannelMemberCountByGroup)
@@ -888,8 +891,8 @@ func (_m *ChannelStore) GetMemberCountsByGroup(channelID string, includeTimezone
}
var r1 error
if rf, ok := ret.Get(1).(func(string, bool) error); ok {
r1 = rf(channelID, includeTimezones)
if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok {
r1 = rf(ctx, channelID, includeTimezones)
} else {
r1 = ret.Error(1)
}

View File

@@ -4846,7 +4846,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) {
require.False(t, updatedTeamMember.SchemeGuest)
require.True(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.False(t, updatedChannelMember.SchemeGuest)
require.True(t, updatedChannelMember.SchemeUser)
@@ -4891,7 +4891,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) {
require.False(t, updatedTeamMember.SchemeGuest)
require.True(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.False(t, updatedChannelMember.SchemeGuest)
require.True(t, updatedChannelMember.SchemeUser)
@@ -4987,7 +4987,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) {
require.False(t, updatedTeamMember.SchemeGuest)
require.True(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.False(t, updatedChannelMember.SchemeGuest)
require.True(t, updatedChannelMember.SchemeUser)
@@ -5032,7 +5032,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) {
require.False(t, updatedTeamMember.SchemeGuest)
require.True(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.False(t, updatedChannelMember.SchemeGuest)
require.True(t, updatedChannelMember.SchemeUser)
@@ -5098,7 +5098,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) {
require.False(t, updatedTeamMember.SchemeGuest)
require.True(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user1.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user1.Id)
require.NoError(t, nErr)
require.False(t, updatedChannelMember.SchemeGuest)
require.True(t, updatedChannelMember.SchemeUser)
@@ -5112,7 +5112,7 @@ func testUserStorePromoteGuestToUser(t *testing.T, ss store.Store) {
require.True(t, notUpdatedTeamMember.SchemeGuest)
require.False(t, notUpdatedTeamMember.SchemeUser)
notUpdatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user2.Id)
notUpdatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user2.Id)
require.NoError(t, nErr)
require.True(t, notUpdatedChannelMember.SchemeGuest)
require.False(t, notUpdatedChannelMember.SchemeUser)
@@ -5159,7 +5159,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) {
require.True(t, updatedTeamMember.SchemeGuest)
require.False(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, updatedUser.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, updatedUser.Id)
require.NoError(t, nErr)
require.True(t, updatedChannelMember.SchemeGuest)
require.False(t, updatedChannelMember.SchemeUser)
@@ -5202,7 +5202,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) {
require.True(t, updatedTeamMember.SchemeGuest)
require.False(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.True(t, updatedChannelMember.SchemeGuest)
require.False(t, updatedChannelMember.SchemeUser)
@@ -5292,7 +5292,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) {
require.True(t, updatedTeamMember.SchemeGuest)
require.False(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.True(t, updatedChannelMember.SchemeGuest)
require.False(t, updatedChannelMember.SchemeUser)
@@ -5335,7 +5335,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) {
require.True(t, updatedTeamMember.SchemeGuest)
require.False(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user.Id)
require.NoError(t, nErr)
require.True(t, updatedChannelMember.SchemeGuest)
require.False(t, updatedChannelMember.SchemeUser)
@@ -5399,7 +5399,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) {
require.True(t, updatedTeamMember.SchemeGuest)
require.False(t, updatedTeamMember.SchemeUser)
updatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user1.Id)
updatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user1.Id)
require.NoError(t, nErr)
require.True(t, updatedChannelMember.SchemeGuest)
require.False(t, updatedChannelMember.SchemeUser)
@@ -5413,7 +5413,7 @@ func testUserStoreDemoteUserToGuest(t *testing.T, ss store.Store) {
require.False(t, notUpdatedTeamMember.SchemeGuest)
require.True(t, notUpdatedTeamMember.SchemeUser)
notUpdatedChannelMember, nErr := ss.Channel().GetMember(channel.Id, user2.Id)
notUpdatedChannelMember, nErr := ss.Channel().GetMember(context.Background(), channel.Id, user2.Id)
require.NoError(t, nErr)
require.False(t, notUpdatedChannelMember.SchemeGuest)
require.True(t, notUpdatedChannelMember.SchemeUser)

View File

@@ -1095,10 +1095,10 @@ func (s *TimerLayerChannelStore) GetGuestCount(channelID string, allowFromCache
return result, err
}
func (s *TimerLayerChannelStore) GetMember(channelID string, userId string) (*model.ChannelMember, error) {
func (s *TimerLayerChannelStore) GetMember(ctx context.Context, channelID string, userId string) (*model.ChannelMember, error) {
start := timemodule.Now()
result, err := s.ChannelStore.GetMember(channelID, userId)
result, err := s.ChannelStore.GetMember(ctx, channelID, userId)
elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second)
if s.Root.Metrics != nil {
@@ -1143,10 +1143,10 @@ func (s *TimerLayerChannelStore) GetMemberCountFromCache(channelID string) int64
return result
}
func (s *TimerLayerChannelStore) GetMemberCountsByGroup(channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
func (s *TimerLayerChannelStore) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, error) {
start := timemodule.Now()
result, err := s.ChannelStore.GetMemberCountsByGroup(channelID, includeTimezones)
result, err := s.ChannelStore.GetMemberCountsByGroup(ctx, channelID, includeTimezones)
elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second)
if s.Root.Metrics != nil {