mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
[MM-56174] Account for archived channels in channel member for post permission check (#25837)
* [MM-56174] Account for archived channels in channel member for post permission check * Add tests
This commit is contained in:
@@ -172,7 +172,7 @@ func (a *App) SessionHasPermissionToGroup(session model.Session, groupID string,
|
||||
}
|
||||
|
||||
func (a *App) SessionHasPermissionToChannelByPost(session model.Session, postID string, permission *model.Permission) bool {
|
||||
if channelMember, err := a.Srv().Store().Channel().GetMemberForPost(postID, session.UserId); err == nil {
|
||||
if channelMember, err := a.Srv().Store().Channel().GetMemberForPost(postID, session.UserId, *a.Config().TeamSettings.ExperimentalViewArchivedChannels); err == nil {
|
||||
if a.RolesGrantPermission(channelMember.GetRoles(), permission.Id) {
|
||||
return true
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func (a *App) HasPermissionToChannel(c request.CTX, askingUserId string, channel
|
||||
}
|
||||
|
||||
func (a *App) HasPermissionToChannelByPost(c request.CTX, askingUserId string, postID string, permission *model.Permission) bool {
|
||||
if channelMember, err := a.Srv().Store().Channel().GetMemberForPost(postID, askingUserId); err == nil {
|
||||
if channelMember, err := a.Srv().Store().Channel().GetMemberForPost(postID, askingUserId, *a.Config().TeamSettings.ExperimentalViewArchivedChannels); err == nil {
|
||||
if a.RolesGrantPermission(channelMember.GetRoles(), permission.Id) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -667,3 +667,107 @@ func TestHasPermissionToReadChannel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHasPermissionToChannelByPost(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
session, err := th.App.CreateSession(th.Context, &model.Session{
|
||||
UserId: th.BasicUser.Id,
|
||||
Roles: model.SystemUserRoleId,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
session2, err := th.App.CreateSession(th.Context, &model.Session{
|
||||
UserId: th.BasicUser2.Id,
|
||||
Roles: model.SystemUserRoleId,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
channel := th.CreateChannel(th.Context, th.BasicTeam)
|
||||
th.App.AddUserToChannel(th.Context, th.BasicUser, channel, false)
|
||||
post := th.CreatePost(channel)
|
||||
|
||||
archivedChannel := th.CreateChannel(th.Context, th.BasicTeam)
|
||||
archivedPost := th.CreatePost(archivedChannel)
|
||||
th.App.DeleteChannel(th.Context, archivedChannel, th.SystemAdminUser.Id)
|
||||
|
||||
t.Run("read channel", func(t *testing.T) {
|
||||
require.Equal(t, true, th.App.SessionHasPermissionToChannelByPost(*session, post.Id, model.PermissionReadChannel))
|
||||
require.Equal(t, false, th.App.SessionHasPermissionToChannelByPost(*session2, post.Id, model.PermissionReadChannel))
|
||||
})
|
||||
|
||||
t.Run("read archived channel - setting off", func(t *testing.T) {
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
cfg.TeamSettings.ExperimentalViewArchivedChannels = model.NewBool(false)
|
||||
})
|
||||
require.Equal(t, false, th.App.SessionHasPermissionToChannelByPost(*session, archivedPost.Id, model.PermissionReadChannel))
|
||||
require.Equal(t, false, th.App.SessionHasPermissionToChannelByPost(*session2, archivedPost.Id, model.PermissionReadChannel))
|
||||
})
|
||||
|
||||
t.Run("read archived channel - setting on", func(t *testing.T) {
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
cfg.TeamSettings.ExperimentalViewArchivedChannels = model.NewBool(true)
|
||||
})
|
||||
require.Equal(t, true, th.App.SessionHasPermissionToChannelByPost(*session, archivedPost.Id, model.PermissionReadChannel))
|
||||
require.Equal(t, false, th.App.SessionHasPermissionToChannelByPost(*session2, archivedPost.Id, model.PermissionReadChannel))
|
||||
})
|
||||
|
||||
t.Run("read public channel", func(t *testing.T) {
|
||||
require.Equal(t, true, th.App.SessionHasPermissionToChannelByPost(*session, post.Id, model.PermissionReadPublicChannel))
|
||||
require.Equal(t, true, th.App.SessionHasPermissionToChannelByPost(*session2, post.Id, model.PermissionReadPublicChannel))
|
||||
})
|
||||
|
||||
t.Run("read channel - user is admin", func(t *testing.T) {
|
||||
adminSession, err := th.App.CreateSession(th.Context, &model.Session{
|
||||
UserId: th.SystemAdminUser.Id,
|
||||
Roles: model.SystemAdminRoleId,
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, true, th.App.SessionHasPermissionToChannelByPost(*adminSession, post.Id, model.PermissionReadChannel))
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasPermissionToChannelByPost(t *testing.T) {
|
||||
th := Setup(t).InitBasic()
|
||||
defer th.TearDown()
|
||||
|
||||
channel := th.CreateChannel(th.Context, th.BasicTeam)
|
||||
th.App.AddUserToChannel(th.Context, th.BasicUser, channel, false)
|
||||
post := th.CreatePost(channel)
|
||||
|
||||
archivedChannel := th.CreateChannel(th.Context, th.BasicTeam)
|
||||
archivedPost := th.CreatePost(archivedChannel)
|
||||
th.App.DeleteChannel(th.Context, archivedChannel, th.SystemAdminUser.Id)
|
||||
|
||||
t.Run("read channel", func(t *testing.T) {
|
||||
require.Equal(t, true, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser.Id, post.Id, model.PermissionReadChannel))
|
||||
require.Equal(t, false, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser2.Id, post.Id, model.PermissionReadChannel))
|
||||
})
|
||||
|
||||
t.Run("read archived channel - setting off", func(t *testing.T) {
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
cfg.TeamSettings.ExperimentalViewArchivedChannels = model.NewBool(false)
|
||||
})
|
||||
require.Equal(t, false, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser.Id, archivedPost.Id, model.PermissionReadChannel))
|
||||
require.Equal(t, false, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser2.Id, archivedPost.Id, model.PermissionReadChannel))
|
||||
})
|
||||
|
||||
t.Run("read archived channel - setting on", func(t *testing.T) {
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
cfg.TeamSettings.ExperimentalViewArchivedChannels = model.NewBool(true)
|
||||
})
|
||||
require.Equal(t, true, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser.Id, archivedPost.Id, model.PermissionReadChannel))
|
||||
require.Equal(t, false, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser2.Id, archivedPost.Id, model.PermissionReadChannel))
|
||||
})
|
||||
|
||||
t.Run("read public channel", func(t *testing.T) {
|
||||
require.Equal(t, true, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser.Id, post.Id, model.PermissionReadPublicChannel))
|
||||
require.Equal(t, true, th.App.HasPermissionToChannelByPost(th.Context, th.BasicUser2.Id, post.Id, model.PermissionReadPublicChannel))
|
||||
})
|
||||
|
||||
t.Run("read channel - user is admin", func(t *testing.T) {
|
||||
require.Equal(t, true, th.App.HasPermissionToChannelByPost(th.Context, th.SystemAdminUser.Id, post.Id, model.PermissionReadChannel))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1554,7 +1554,7 @@ func (s *OpenTracingLayerChannelStore) GetMemberCountsByGroup(ctx context.Contex
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *OpenTracingLayerChannelStore) GetMemberForPost(postID string, userID string) (*model.ChannelMember, error) {
|
||||
func (s *OpenTracingLayerChannelStore) GetMemberForPost(postID string, userID string, includeArchivedChannels bool) (*model.ChannelMember, error) {
|
||||
origCtx := s.Root.Store.Context()
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "ChannelStore.GetMemberForPost")
|
||||
s.Root.Store.SetContext(newCtx)
|
||||
@@ -1563,7 +1563,7 @@ func (s *OpenTracingLayerChannelStore) GetMemberForPost(postID string, userID st
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
result, err := s.ChannelStore.GetMemberForPost(postID, userID)
|
||||
result, err := s.ChannelStore.GetMemberForPost(postID, userID, includeArchivedChannels)
|
||||
if err != nil {
|
||||
span.LogFields(spanlog.Error(err))
|
||||
ext.Error.Set(span, true)
|
||||
|
||||
@@ -1727,11 +1727,11 @@ func (s *RetryLayerChannelStore) GetMemberCountsByGroup(ctx context.Context, cha
|
||||
|
||||
}
|
||||
|
||||
func (s *RetryLayerChannelStore) GetMemberForPost(postID string, userID string) (*model.ChannelMember, error) {
|
||||
func (s *RetryLayerChannelStore) GetMemberForPost(postID string, userID string, includeArchivedChannels bool) (*model.ChannelMember, error) {
|
||||
|
||||
tries := 0
|
||||
for {
|
||||
result, err := s.ChannelStore.GetMemberForPost(postID, userID)
|
||||
result, err := s.ChannelStore.GetMemberForPost(postID, userID, includeArchivedChannels)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -2107,7 +2107,7 @@ func (s SqlChannelStore) IsUserInChannelUseCache(userId string, channelId string
|
||||
return false
|
||||
}
|
||||
|
||||
func (s SqlChannelStore) GetMemberForPost(postId string, userId string) (*model.ChannelMember, error) {
|
||||
func (s SqlChannelStore) GetMemberForPost(postId string, userId string, includeArchivedChannels bool) (*model.ChannelMember, error) {
|
||||
var dbMember channelMemberWithSchemeRoles
|
||||
query := `
|
||||
SELECT
|
||||
@@ -2147,6 +2147,10 @@ func (s SqlChannelStore) GetMemberForPost(postId string, userId string) (*model.
|
||||
ChannelMembers.UserId = ?
|
||||
AND
|
||||
Posts.Id = ?`
|
||||
|
||||
if !includeArchivedChannels {
|
||||
query += " AND Channels.DeleteAt = 0"
|
||||
}
|
||||
if err := s.GetReplicaX().Get(&dbMember, query, userId, postId); err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to get ChannelMember with postId=%s and userId=%s", postId, userId)
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ type ChannelStore interface {
|
||||
IsUserInChannelUseCache(userID string, channelID string) bool
|
||||
GetAllChannelMembersNotifyPropsForChannel(channelID string, allowFromCache bool) (map[string]model.StringMap, error)
|
||||
InvalidateCacheForChannelMembersNotifyProps(channelID string)
|
||||
GetMemberForPost(postID string, userID string) (*model.ChannelMember, error)
|
||||
GetMemberForPost(postID string, userID string, includeArchivedChannels bool) (*model.ChannelMember, error)
|
||||
InvalidateMemberCount(channelID string)
|
||||
GetMemberCountFromCache(channelID string) int64
|
||||
GetFileCount(channelID string) (int64, error)
|
||||
|
||||
@@ -4917,11 +4917,11 @@ func testChannelStoreGetMemberForPost(t *testing.T, rctx request.CTX, ss store.S
|
||||
})
|
||||
require.NoError(t, nErr)
|
||||
|
||||
r1, err := ss.Channel().GetMemberForPost(p1.Id, m1.UserId)
|
||||
r1, err := ss.Channel().GetMemberForPost(p1.Id, m1.UserId, false)
|
||||
require.NoError(t, err, err)
|
||||
require.Equal(t, channelMemberToJSON(t, m1), channelMemberToJSON(t, r1), "invalid returned channel member")
|
||||
|
||||
_, err = ss.Channel().GetMemberForPost(p1.Id, model.NewId())
|
||||
_, err = ss.Channel().GetMemberForPost(p1.Id, model.NewId(), false)
|
||||
require.Error(t, err, "shouldn't have returned a member")
|
||||
}
|
||||
|
||||
|
||||
@@ -1288,25 +1288,25 @@ func (_m *ChannelStore) GetMemberCountsByGroup(ctx context.Context, channelID st
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetMemberForPost provides a mock function with given fields: postID, userID
|
||||
func (_m *ChannelStore) GetMemberForPost(postID string, userID string) (*model.ChannelMember, error) {
|
||||
ret := _m.Called(postID, userID)
|
||||
// GetMemberForPost provides a mock function with given fields: postID, userID, includeArchivedChannels
|
||||
func (_m *ChannelStore) GetMemberForPost(postID string, userID string, includeArchivedChannels bool) (*model.ChannelMember, error) {
|
||||
ret := _m.Called(postID, userID, includeArchivedChannels)
|
||||
|
||||
var r0 *model.ChannelMember
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string, string) (*model.ChannelMember, error)); ok {
|
||||
return rf(postID, userID)
|
||||
if rf, ok := ret.Get(0).(func(string, string, bool) (*model.ChannelMember, error)); ok {
|
||||
return rf(postID, userID, includeArchivedChannels)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string, string) *model.ChannelMember); ok {
|
||||
r0 = rf(postID, userID)
|
||||
if rf, ok := ret.Get(0).(func(string, string, bool) *model.ChannelMember); ok {
|
||||
r0 = rf(postID, userID, includeArchivedChannels)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*model.ChannelMember)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string, string) error); ok {
|
||||
r1 = rf(postID, userID)
|
||||
if rf, ok := ret.Get(1).(func(string, string, bool) error); ok {
|
||||
r1 = rf(postID, userID, includeArchivedChannels)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
@@ -1445,10 +1445,10 @@ func (s *TimerLayerChannelStore) GetMemberCountsByGroup(ctx context.Context, cha
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *TimerLayerChannelStore) GetMemberForPost(postID string, userID string) (*model.ChannelMember, error) {
|
||||
func (s *TimerLayerChannelStore) GetMemberForPost(postID string, userID string, includeArchivedChannels bool) (*model.ChannelMember, error) {
|
||||
start := time.Now()
|
||||
|
||||
result, err := s.ChannelStore.GetMemberForPost(postID, userID)
|
||||
result, err := s.ChannelStore.GetMemberForPost(postID, userID, includeArchivedChannels)
|
||||
|
||||
elapsed := float64(time.Since(start)) / float64(time.Second)
|
||||
if s.Root.Metrics != nil {
|
||||
|
||||
Reference in New Issue
Block a user