Fix racy test issues (#24971)

This commit is contained in:
Ben Schumacher 2023-11-06 12:26:17 +01:00 committed by GitHub
parent 366d1613b7
commit 486e836b83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 191 additions and 172 deletions

View File

@ -140,13 +140,12 @@ func setupTestHelper(dbStore store.Store, searchEngine *searchengine.Broker, ent
th := &TestHelper{
App: app.New(app.ServerConnector(s.Channels())),
Server: s,
Context: request.EmptyContext(testLogger),
ConfigStore: configStore,
IncludeCacheLayer: includeCache,
Context: request.EmptyContext(testLogger),
TestLogger: testLogger,
LogBuffer: buffer,
}
th.Context.SetLogger(testLogger)
if s.Platform().SearchEngine != nil && s.Platform().SearchEngine.BleveEngine != nil && searchEngine != nil {
searchEngine.BleveEngine = s.Platform().SearchEngine.BleveEngine

View File

@ -4,7 +4,6 @@
package api4
import (
"context"
"encoding/json"
"net/http"
"strconv"
@ -1460,9 +1459,8 @@ func getChannelMember(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
ctx := c.AppContext
ctx.SetContext(app.WithMaster(ctx.Context()))
member, err := c.App.GetChannelMember(ctx, c.Params.ChannelId, c.Params.UserId)
c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
member, err := c.App.GetChannelMember(c.AppContext, c.Params.ChannelId, c.Params.UserId)
if err != nil {
c.Err = err
return
@ -2004,7 +2002,7 @@ func channelMemberCountsByGroup(c *Context, w http.ResponseWriter, r *http.Reque
includeTimezones := r.URL.Query().Get("include_timezones") == "true"
channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(app.WithMaster(context.Background()), c.Params.ChannelId, includeTimezones)
channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(c.AppContext.With(app.RequestContextWithMaster), c.Params.ChannelId, includeTimezones)
if appErr != nil {
c.Err = appErr
return

View File

@ -195,7 +195,7 @@ func uploadRemoteData(c *Context, w http.ResponseWriter, r *http.Request) {
defer c.LogAuditRec(auditRec)
audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId)
c.AppContext.SetContext(app.WithMaster(c.AppContext.Context()))
c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId)
if err != nil {
c.Err = err

View File

@ -123,7 +123,7 @@ func uploadData(c *Context, w http.ResponseWriter, r *http.Request) {
defer c.LogAuditRec(auditRec)
audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId)
c.AppContext.SetContext(app.WithMaster(c.AppContext.Context()))
c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId)
if err != nil {
c.Err = err

View File

@ -1953,11 +1953,12 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) {
c.LogAuditWithUserId(user.Id, "authenticated")
isMobileDevice := utils.IsMobileRequest(r)
err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false)
session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false)
if err != nil {
c.Err = err
return
}
c.AppContext = c.AppContext.WithSession(session)
c.LogAuditWithUserId(user.Id, "success")
@ -1995,11 +1996,12 @@ func loginWithDesktopToken(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false)
session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false)
if err != nil {
c.Err = err
return
}
c.AppContext = c.AppContext.WithSession(session)
c.App.AttachSessionCookies(c.AppContext, w, r)
@ -2048,12 +2050,13 @@ func loginCWS(c *Context, w http.ResponseWriter, r *http.Request) {
audit.AddEventParameterAuditable(auditRec, "user", user)
c.LogAuditWithUserId(user.Id, "authenticated")
isMobileDevice := utils.IsMobileRequest(r)
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false)
session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false)
if err != nil {
c.LogErrorByCode(err)
http.Redirect(w, r, *c.App.Config().ServiceSettings.SiteURL, http.StatusFound)
return
}
c.AppContext = c.AppContext.WithSession(session)
c.LogAuditWithUserId(user.Id, "success")
c.App.AttachSessionCookies(c.AppContext, w, r)

View File

@ -553,7 +553,7 @@ type AppIface interface {
DoEmojisPermissionsMigration()
DoGuestRolesCreationMigration()
DoLocalRequest(c request.CTX, rawURL string, body []byte) (*http.Response, *model.AppError)
DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError
DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError)
DoPostAction(c request.CTX, postID, actionId, userID, selectedOption string) (string, *model.AppError)
DoPostActionWithCookie(c request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie) (string, *model.AppError)
DoSystemConsoleRolesCreationMigration()
@ -687,7 +687,7 @@ type AppIface interface {
GetLatestVersion(latestVersionUrl string) (*model.GithubReleaseInfo, *model.AppError)
GetLogs(c request.CTX, page, perPage int) ([]string, *model.AppError)
GetLogsSkipSend(page, perPage int, logFilter *model.LogFilter) ([]string, *model.AppError)
GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
GetMessageForNotification(post *model.Post, translateFunc i18n.TranslateFunc) string
GetMultipleEmojiByName(c request.CTX, names []string) ([]*model.Emoji, *model.AppError)
GetNewUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError)

View File

@ -1666,7 +1666,9 @@ func (a *App) AddChannelMember(c request.CTX, userID string, channel *model.Chan
}
} else {
a.Srv().Go(func() {
a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID)
if err := a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID); err != nil {
c.Logger().Error("Failed to post AddToChannel message", mlog.Err(err))
}
})
}
@ -2363,7 +2365,9 @@ func (a *App) LeaveChannel(c request.CTX, channelID string, userID string) *mode
}
a.Srv().Go(func() {
a.postLeaveChannelMessage(c, user, channel)
if err := a.postLeaveChannelMessage(c, user, channel); err != nil {
c.Logger().Error("Failed to post LeaveChannel message", mlog.Err(err))
}
})
return nil
@ -3458,8 +3462,8 @@ func (a *App) ClearChannelMembersCache(c request.CTX, channelID string) error {
return nil
}
func (a *App) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(ctx, channelID, includeTimezones)
func (a *App) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(rctx.Context(), channelID, includeTimezones)
if err != nil {
return nil, model.NewAppError("GetMemberCountsByGroup", "app.channel.get_member_count.app_error", nil, "", http.StatusInternalServerError).Wrap(err)
}

View File

@ -2152,7 +2152,7 @@ func TestGetMemberCountsByGroup(t *testing.T) {
mockChannelStore.On("GetMemberCountsByGroup", context.Background(), "channelID", true).Return(cmc, nil)
mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("GetDBSchemaVersion").Return(1, nil)
resp, err := th.App.GetMemberCountsByGroup(context.Background(), "channelID", true)
resp, err := th.App.GetMemberCountsByGroup(th.Context, "channelID", true)
require.Nil(t, err)
require.ElementsMatch(t, cmc, resp)
}

View File

@ -34,7 +34,7 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo
job.Type = model.ComplianceTypeAdhoc
rctx.SetLogger(rctx.Logger().With(job.LoggerFields()...))
rctx = rctx.WithLogger(rctx.Logger().With(job.LoggerFields()...))
job, err := a.Srv().Store().Compliance().Save(job)
if err != nil {
@ -48,11 +48,10 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo
}
jCopy := job.DeepCopy()
crctx := rctx.Clone()
a.Srv().Go(func() {
err := a.Compliance().RunComplianceJob(crctx, jCopy)
err := a.Compliance().RunComplianceJob(rctx, jCopy)
if err != nil {
crctx.Logger().Warn("Error running compliance job", mlog.Err(err))
rctx.Logger().Warn("Error running compliance job", mlog.Err(err))
}
})

View File

@ -4,16 +4,14 @@
package app
import (
"context"
"github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
)
// WithMaster adds the context value that master DB should be selected for this request.
func WithMaster(ctx context.Context) context.Context {
return sqlstore.WithMaster(ctx)
// RequestContextWithMaster adds the context value that master DB should be selected for this request.
func RequestContextWithMaster(c request.CTX) request.CTX {
return sqlstore.RequestContextWithMaster(c)
}
func pluginContext(c request.CTX) *plugin.Context {

View File

@ -796,11 +796,10 @@ func (a *App) UploadFileX(c request.CTX, channelID, name string, input io.Reader
if *a.Config().FileSettings.ExtractContent {
infoCopy := *t.fileinfo
crctx := c.Clone()
a.Srv().GoBuffered(func() {
err := a.ExtractContentFromFileInfo(crctx, &infoCopy)
err := a.ExtractContentFromFileInfo(c, &infoCopy)
if err != nil {
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
}
})
}
@ -1048,11 +1047,10 @@ func (a *App) DoUploadFileExpectModification(c request.CTX, now time.Time, rawTe
if *a.Config().FileSettings.ExtractContent {
infoCopy := *info
crctx := c.Clone()
a.Srv().GoBuffered(func() {
err := a.ExtractContentFromFileInfo(crctx, &infoCopy)
err := a.ExtractContentFromFileInfo(c, &infoCopy)
if err != nil {
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
}
})
}

View File

@ -99,7 +99,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
IncludeCacheLayer: includeCacheLayer,
ConfigStore: configStore,
}
th.Context.SetLogger(testLogger)
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.MaxUsersPerTeam = 50 })
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.RateLimitSettings.Enable = false })

View File

@ -268,7 +268,7 @@ func (a *App) bulkImport(c request.CTX, jsonlReader io.Reader, attachmentsReader
linesChan = make(chan imports.LineImportWorkerData, workers)
for i := 0; i < workers; i++ {
wg.Add(1)
go a.bulkImportWorker(c.Clone(), dryRun, &wg, linesChan, errorsChan)
go a.bulkImportWorker(c, dryRun, &wg, linesChan, errorsChan)
}
}

View File

@ -156,7 +156,7 @@ func (a *App) GetUserForLogin(c request.CTX, id, loginId string) (*model.User, *
return nil, model.NewAppError("GetUserForLogin", "store.sql_user.get_for_login.app_error", nil, "", http.StatusBadRequest)
}
func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError {
func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError) {
var rejectionReason string
pluginContext := pluginContext(c)
a.ch.RunMultiHook(func(hooks plugin.Hooks) bool {
@ -165,7 +165,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
}, plugin.UserWillLogInID)
if rejectionReason != "" {
return model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest)
return nil, model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest)
}
session := &model.Session{UserId: user.Id, Roles: user.GetRawRoles(), DeviceId: deviceID, IsOAuth: false, Props: map[string]string{
@ -181,7 +181,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
// A special case where we logout of all other sessions with the same Id
if err := a.RevokeSessionsForDeviceId(c, user.Id, deviceID, ""); err != nil {
err.StatusCode = http.StatusInternalServerError
return err
return nil, err
}
} else if isMobile {
a.ch.srv.platform.SetSessionExpireInHours(session, *a.Config().ServiceSettings.SessionLengthMobileInHours)
@ -210,12 +210,12 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
var err *model.AppError
if session, err = a.CreateSession(c, session); err != nil {
err.StatusCode = http.StatusInternalServerError
return err
return nil, err
}
w.Header().Set(model.HeaderToken, session.Token)
c.SetSession(session)
c = c.WithSession(session)
if a.Srv().License() != nil && *a.Srv().License().Features.LDAP && a.Ldap() != nil {
userVal := *user
sessionVal := *session
@ -231,7 +231,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
}, plugin.UserHasLoggedInID)
})
return nil
return session, nil
}
func (a *App) AttachCloudSessionCookie(c request.CTX, w http.ResponseWriter, r *http.Request) {

View File

@ -3828,7 +3828,7 @@ func (a *OpenTracingAppLayer) DoLocalRequest(c request.CTX, rawURL string, body
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) *model.AppError {
func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) (*model.Session, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DoLogin")
@ -3840,14 +3840,14 @@ func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *h
}()
defer span.Finish()
resultVar0 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml)
resultVar0, resultVar1 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml)
if resultVar0 != nil {
span.LogFields(spanlog.Error(resultVar0))
if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar1))
ext.Error.Set(span, true)
}
return resultVar0
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) DoPermissionsMigrations() error {
@ -7320,7 +7320,7 @@ func (a *OpenTracingAppLayer) GetMarketplacePlugins(filter *model.MarketplacePlu
return resultVar0, resultVar1
}
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetMemberCountsByGroup")
@ -7332,7 +7332,7 @@ func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channe
}()
defer span.Finish()
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(ctx, channelID, includeTimezones)
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(rctx, channelID, includeTimezones)
if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar1))

View File

@ -1263,7 +1263,7 @@ func (api *PluginAPI) UploadData(us *model.UploadSession, rd io.Reader) (*model.
func (api *PluginAPI) GetUploadSession(uploadID string) (*model.UploadSession, error) {
// We want to fetch from master DB to avoid a potential read-after-write on the plugin side.
api.ctx.SetContext(WithMaster(api.ctx.Context()))
api.ctx = api.ctx.With(RequestContextWithMaster)
fi, err := api.app.GetUploadSession(api.ctx, uploadID)
if err != nil {
return nil, err

View File

@ -727,9 +727,10 @@ func TestUserWillLogIn_Blocked(t *testing.T) {
r := &http.Request{}
w := httptest.NewRecorder()
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
assert.Contains(t, err.Id, "Login rejected by plugin", "Expected Login rejected by plugin, got %s", err.Id)
assert.Nil(t, session)
}
func TestUserWillLogInIn_Passed(t *testing.T) {
@ -766,10 +767,11 @@ func TestUserWillLogInIn_Passed(t *testing.T) {
r := &http.Request{}
w := httptest.NewRecorder()
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
assert.Nil(t, err, "Expected nil, got %s", err)
assert.Equal(t, th.Context.Session().UserId, th.BasicUser.Id)
require.NotNil(t, session)
assert.Equal(t, session.UserId, th.BasicUser.Id)
}
func TestUserHasLoggedIn(t *testing.T) {
@ -807,9 +809,10 @@ func TestUserHasLoggedIn(t *testing.T) {
r := &http.Request{}
w := httptest.NewRecorder()
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
assert.Nil(t, err, "Expected nil, got %s", err)
assert.NotNil(t, session)
time.Sleep(2 * time.Second)

View File

@ -532,7 +532,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User
a.Srv().Go(func() {
_, err := a.SendAutoResponseIfNecessary(c, channel, user, post)
if err != nil {
mlog.Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err))
c.Logger().Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err))
}
})
}
@ -540,7 +540,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User
if triggerWebhooks {
a.Srv().Go(func() {
if err := a.handleWebhookEvents(c, post, team, channel, user); err != nil {
mlog.Error(err.Error())
c.Logger().Error("Failed to handle webhook event", mlog.Err(err))
}
})
}
@ -1378,7 +1378,7 @@ func (a *App) DeletePost(c request.CTX, postID, deleteByID string) (*model.Post,
a.Srv().Go(func() {
if err = a.RemoveNotifications(c, post, channel); err != nil {
a.Log().Error("DeletePost failed to delete notification", mlog.Err(err))
c.Logger().Error("DeletePost failed to delete notification", mlog.Err(err))
}
})

View File

@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/app/email"
emailmocks "github.com/mattermost/mattermost/server/v8/channels/app/email/mocks"
"github.com/mattermost/mattermost/server/v8/channels/app/teams"
@ -1027,7 +1028,12 @@ func TestLeaveTeamPanic(t *testing.T) {
mockLicenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil)
mockTeamStore := mocks.TeamStore{}
mockTeamStore.On("GetMember", sqlstore.RequestContextWithMaster(th.Context), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil)
mockTeamStore.On("GetMember", mock.AnythingOfType("*request.Context"), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil).Run(func(args mock.Arguments) {
c, ok := args[0].(request.CTX)
require.True(t, ok)
sqlstore.HasMaster(c.Context())
})
mockTeamStore.On("UpdateMember", mock.Anything).Return(nil, errors.New("repro error")) // This is the line that triggers the error
mockStore.On("Channel").Return(&mockChannelStore)

View File

@ -203,7 +203,7 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (
}()
// fetch the session from store to check for inconsistencies.
c.SetContext(WithMaster(c.Context()))
c = c.With(RequestContextWithMaster)
if storedSession, err := a.GetUploadSession(c, us.Id); err != nil {
return nil, err
} else if us.FileOffset != storedSession.FileOffset {
@ -318,11 +318,10 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (
if *a.Config().FileSettings.ExtractContent {
infoCopy := *info
crctx := c.Clone()
a.Srv().Go(func() {
err := a.ExtractContentFromFileInfo(crctx, &infoCopy)
err := a.ExtractContentFromFileInfo(c, &infoCopy)
if err != nil {
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
}
})
}

View File

@ -5,6 +5,7 @@ package jobs
import (
"net/http"
"sync/atomic"
"time"
"github.com/mattermost/mattermost/server/public/model"
@ -33,8 +34,9 @@ type BatchMigrationWorker struct {
store store.Store
app BatchMigrationWorkerAppIFace
stop chan bool
stop chan struct{}
stopped chan bool
closed atomic.Bool
jobs chan model.Job
migrationKey string
@ -49,7 +51,7 @@ func MakeBatchMigrationWorker(jobServer *JobServer, store store.Store, app Batch
logger: jobServer.Logger().With(mlog.String("worker_name", migrationKey)),
store: store,
app: app,
stop: make(chan bool, 1),
stop: make(chan struct{}),
stopped: make(chan bool, 1),
jobs: make(chan model.Job),
migrationKey: migrationKey,
@ -64,7 +66,9 @@ func (worker *BatchMigrationWorker) Run() {
worker.logger.Debug("Worker started")
// We have to re-assign the stop channel again, because
// it might happen that the job was restarted due to a config change.
worker.stop = make(chan bool, 1)
if worker.closed.CompareAndSwap(true, false) {
worker.stop = make(chan struct{})
}
defer func() {
worker.logger.Debug("Worker finished")
@ -84,6 +88,11 @@ func (worker *BatchMigrationWorker) Run() {
// Stop interrupts the worker even if the migration has not yet completed.
func (worker *BatchMigrationWorker) Stop() {
// Set to close, and if already closed before, then return.
if !worker.closed.CompareAndSwap(false, true) {
return
}
worker.logger.Debug("Worker stopping")
close(worker.stop)
<-worker.stopped

View File

@ -88,7 +88,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
IncludeCacheLayer: includeCacheLayer,
ConfigStore: configStore,
}
th.Context.SetLogger(testLogger)
prevListenAddress := *th.App.Config().ServiceSettings.ListenAddress
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = "localhost:0" })

View File

@ -100,10 +100,10 @@ func (worker *Worker) DoJob(job *model.Job) {
return
}
cancelContext := request.EmptyContext(worker.logger)
var cancelContext request.CTX = request.EmptyContext(worker.logger)
cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background())
cancelWatcherChan := make(chan struct{}, 1)
cancelContext.SetContext(cancelCtx)
cancelContext = cancelContext.WithContext(cancelCtx)
go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan)
defer cancelCancelWatcher()

View File

@ -29,7 +29,7 @@ type S3PathMigrationWorker struct {
store store.Store
fileBackend *filestore.S3FileBackend
stop chan bool
stop chan struct{}
stopped chan bool
jobs chan model.Job
}
@ -45,7 +45,7 @@ func MakeWorker(jobServer *jobs.JobServer, store store.Store, fileBackend filest
logger: jobServer.Logger().With(mlog.String("worker_name", workerName)),
store: store,
fileBackend: s3Backend,
stop: make(chan bool, 1),
stop: make(chan struct{}),
stopped: make(chan bool, 1),
jobs: make(chan model.Job),
}
@ -56,7 +56,7 @@ func (worker *S3PathMigrationWorker) Run() {
worker.logger.Debug("Worker started")
// We have to re-assign the stop channel again, because
// it might happen that the job was restarted due to a config change.
worker.stop = make(chan bool, 1)
worker.stop = make(chan struct{}, 1)
defer func() {
worker.logger.Debug("Worker finished")

View File

@ -31,7 +31,7 @@ func WithMaster(ctx context.Context) context.Context {
// RequestContextWithMaster adds the context value that master DB should be selected for this request.
func RequestContextWithMaster(c request.CTX) request.CTX {
ctx := WithMaster(c.Context())
c.SetContext(ctx)
c = c.WithContext(ctx)
return c
}

View File

@ -26,39 +26,30 @@ func TestRequestContextWithMaster(t *testing.T) {
assert.True(t, HasMaster(rctx.Context()))
})
t.Run("directly assigning does cause the child to alter the parent", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t)
rctxClone := rctx
rctxClone = RequestContextWithMaster(rctxClone)
assert.True(t, HasMaster(rctx.Context()))
assert.True(t, HasMaster(rctxClone.Context()))
})
t.Run("values get copied from parent", func(t *testing.T) {
t.Run("values get copied from original context", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t)
rctx = RequestContextWithMaster(rctx)
rctxClone := rctx.Clone()
rctxCopy := rctx
assert.True(t, HasMaster(rctx.Context()))
assert.True(t, HasMaster(rctxClone.Context()))
assert.True(t, HasMaster(rctxCopy.Context()))
})
t.Run("changing the child does not alter the parent", func(t *testing.T) {
t.Run("directly assigning does not cause the copy to alter the original context", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t)
rctxClone := rctx.Clone()
rctxClone = RequestContextWithMaster(rctxClone)
rctxCopy := rctx
rctxCopy = RequestContextWithMaster(rctxCopy)
assert.False(t, HasMaster(rctx.Context()))
assert.True(t, HasMaster(rctxClone.Context()))
assert.True(t, HasMaster(rctxCopy.Context()))
})
t.Run("changing the parent does not alter the child", func(t *testing.T) {
t.Run("directly assigning does not cause the original context to alter the copy", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t)
rctxClone := rctx.Clone()
rctxCopy := rctx
rctx = RequestContextWithMaster(rctx)
assert.True(t, HasMaster(rctx.Context()))
assert.False(t, HasMaster(rctxClone.Context()))
assert.False(t, HasMaster(rctxCopy.Context()))
})
}

View File

@ -20,7 +20,7 @@ import (
type Context struct {
App app.AppIface
AppContext *request.Context
AppContext request.CTX
Logger *mlog.Logger
Params *Params
Err *model.AppError

View File

@ -71,7 +71,7 @@ func TestMfaRequired(t *testing.T) {
th.App.Srv().SetLicense(model.NewTestLicense("mfa"))
th.Context.SetSession(&model.Session{Id: "abc", UserId: "userid"})
th.Context = th.Context.WithSession(&model.Session{Id: "abc", UserId: "userid"})
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.AnnouncementSettings.UserNoticesEnabled = false

View File

@ -205,7 +205,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
span.Finish()
}()
c.AppContext.SetContext(ctx)
c.AppContext = c.AppContext.WithContext(ctx)
tmpSrv := *c.App.Srv()
tmpSrv.SetStore(opentracinglayer.New(c.App.Srv().Store(), ctx))
@ -285,7 +285,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString {
c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized)
} else {
c.AppContext.SetSession(session)
c.AppContext = c.AppContext.WithSession(session)
}
// Rate limit by UserID
@ -301,7 +301,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.Logger.Warn("Invalid CWS token", mlog.Err(err))
c.Err = err
} else {
c.AppContext.SetSession(session)
c.AppContext = c.AppContext.WithSession(session)
}
} else if token != "" && c.App.Channels().License() != nil && c.App.Channels().License().HasRemoteClusterService() && tokenLocation == app.TokenLocationRemoteClusterHeader {
// Get the remote cluster
@ -315,7 +315,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.Logger.Warn("Invalid remote cluster token", mlog.Err(err))
c.Err = err
} else {
c.AppContext.SetSession(session)
c.AppContext = c.AppContext.WithSession(session)
}
}
}
@ -327,7 +327,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mlog.String("user_id", c.AppContext.Session().UserId),
mlog.String("method", r.Method),
)
c.AppContext.SetLogger(c.Logger)
c.AppContext = c.AppContext.WithLogger(c.Logger)
if c.Err == nil && h.RequireSession {
c.SessionRequired()
@ -354,7 +354,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// shape IP:PORT (it will be "@" in Linux, for example)
isLocalOrigin := !strings.Contains(r.RemoteAddr, ":")
if *c.App.Config().ServiceSettings.EnableLocalMode && isLocalOrigin {
c.AppContext.SetSession(&model.Session{Local: true})
c.AppContext = c.AppContext.WithSession(&model.Session{Local: true})
} else if !isLocalOrigin {
c.Err = model.NewAppError("", "api.context.local_origin_required.app_error", nil, "LocalOriginRequired", http.StatusUnauthorized)
}
@ -501,7 +501,7 @@ func (h *Handler) checkCSRFToken(c *Context, r *http.Request, token string, toke
}
if !csrfCheckPassed {
c.AppContext.SetSession(&model.Session{})
c.AppContext = c.AppContext.WithSession(&model.Session{})
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
}
}

View File

@ -321,13 +321,14 @@ func completeOAuth(c *Context, w http.ResponseWriter, r *http.Request) {
} else if action == model.OAuthActionSSOToEmail {
redirectURL = app.GetProtocol(r) + "://" + r.Host + "/claim?email=" + url.QueryEscape(props["email"])
} else {
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false)
session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false)
if err != nil {
err.Translate(c.AppContext.T)
mlog.Error(err.Error())
renderError(err)
return
}
c.AppContext = c.AppContext.WithSession(session)
// Old mobile version
if isMobile && !hasRedirectURL {

View File

@ -18,8 +18,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/i18n"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/einterfaces"
@ -397,20 +395,18 @@ func TestMobileLoginWithOAuth(t *testing.T) {
c := &Context{
App: th.App,
AppContext: th.Context,
Logger: th.TestLogger,
Params: &Params{
Service: "gitlab",
},
}
var siteURL = "http://localhost:8065"
siteURL := "http://localhost:8065"
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.SiteURL = siteURL
*cfg.GitLabSettings.Enable = true
})
translationFunc := i18n.GetUserTranslations("en")
c.AppContext.SetT(translationFunc)
c.Logger = th.TestLogger
provider := &MattermostTestProvider{}
einterfaces.RegisterOAuthProvider(model.ServiceGitlab, provider)
@ -617,14 +613,12 @@ func TestOAuthComplete_ErrorMessages(t *testing.T) {
c := &Context{
App: th.App,
AppContext: th.Context,
Logger: th.TestLogger,
Params: &Params{
Service: "gitlab",
},
}
translationFunc := i18n.GetUserTranslations("en")
c.AppContext.SetT(translationFunc)
c.Logger = mlog.CreateConsoleTestLogger(t)
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.GitLabSettings.Enable = true })
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.EnableOAuthServiceProvider = true })
provider := &MattermostTestProvider{}

View File

@ -178,11 +178,12 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
auditRec.AddMeta("obtained_user_id", user.Id)
c.LogAuditWithUserId(user.Id, "obtained user")
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true)
session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true)
if err != nil {
handleError(err)
return
}
c.AppContext = c.AppContext.WithSession(session)
auditRec.Success()
c.LogAuditWithUserId(user.Id, "success")

View File

@ -33,7 +33,7 @@ var URL string
type TestHelper struct {
App app.AppIface
Context *request.Context
Context request.CTX
Server *app.Server
Web *Web
@ -141,7 +141,6 @@ func setupTestHelper(tb testing.TB, includeCacheLayer bool, options []app.Option
IncludeCacheLayer: includeCacheLayer,
TestLogger: testLogger,
}
th.Context.SetLogger(testLogger)
return th
}

View File

@ -138,10 +138,10 @@ func scheduleExportCmdF(command *cobra.Command, args []string) error {
defer cancel()
}
c := request.EmptyContext(a.Log())
c.SetContext(ctx)
var rctx request.CTX = request.EmptyContext(a.Log())
rctx = rctx.WithContext(ctx)
job, err := messageExportI.StartSynchronizeJob(c, startTime)
job, err := messageExportI.StartSynchronizeJob(rctx, startTime)
if err != nil || job.Status == model.JobStatusError || job.Status == model.JobStatusCanceled {
CommandPrintErrorln("ERROR: Message export job failed. Please check the server logs")
} else {

View File

@ -253,10 +253,10 @@ func (worker *BleveIndexerWorker) DoJob(job *model.Job) {
progress.TotalFilesCount = count
}
cancelContext := request.EmptyContext(worker.logger)
var cancelContext request.CTX = request.EmptyContext(worker.logger)
cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background())
cancelWatcherChan := make(chan struct{}, 1)
cancelContext.SetContext(cancelCtx)
cancelContext = cancelContext.WithContext(cancelCtx)
go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan)
defer cancelCancelWatcher()

View File

@ -22,8 +22,7 @@ type Context struct {
userAgent string
acceptLanguage string
logger mlog.LoggerIFace
context context.Context
context context.Context
}
func NewContext(ctx context.Context, requestId, ipAddress, xForwardedFor, path, userAgent, acceptLanguage string, t i18n.TranslateFunc) *Context {
@ -54,8 +53,8 @@ func TestContext(t testing.TB) *Context {
return EmptyContext(logger)
}
// Clone creates a shallow copy of Context, allowing clones to apply per-request changes.
func (c *Context) Clone() CTX {
// clone creates a shallow copy of Context, allowing clones to apply per-request changes.
func (c *Context) clone() *Context {
cCopy := *c
return &cCopy
}
@ -63,6 +62,9 @@ func (c *Context) Clone() CTX {
func (c *Context) T(translationID string, args ...any) string {
return c.t(translationID, args...)
}
func (c *Context) GetT() i18n.TranslateFunc {
return c.t
}
func (c *Context) Session() *model.Session {
return &c.session
}
@ -84,55 +86,71 @@ func (c *Context) UserAgent() string {
func (c *Context) AcceptLanguage() string {
return c.acceptLanguage
}
func (c *Context) Logger() mlog.LoggerIFace {
return c.logger
}
func (c *Context) Context() context.Context {
return c.context
}
func (c *Context) SetSession(s *model.Session) {
c.session = *s
func (c *Context) WithT(t i18n.TranslateFunc) CTX {
rctx := c.clone()
rctx.t = t
return rctx
}
func (c *Context) WithSession(s *model.Session) CTX {
rctx := c.clone()
rctx.session = *s
return rctx
}
func (c *Context) WithRequestId(s string) CTX {
rctx := c.clone()
rctx.requestId = s
return rctx
}
func (c *Context) WithIPAddress(s string) CTX {
rctx := c.clone()
rctx.ipAddress = s
return rctx
}
func (c *Context) WithXForwardedFor(s string) CTX {
rctx := c.clone()
rctx.xForwardedFor = s
return rctx
}
func (c *Context) WithPath(s string) CTX {
rctx := c.clone()
rctx.path = s
return rctx
}
func (c *Context) WithUserAgent(s string) CTX {
rctx := c.clone()
rctx.userAgent = s
return rctx
}
func (c *Context) WithAcceptLanguage(s string) CTX {
rctx := c.clone()
rctx.acceptLanguage = s
return rctx
}
func (c *Context) WithContext(ctx context.Context) CTX {
rctx := c.clone()
rctx.context = ctx
return rctx
}
func (c *Context) WithLogger(logger mlog.LoggerIFace) CTX {
rctx := c.clone()
rctx.logger = logger
return rctx
}
func (c *Context) SetT(t i18n.TranslateFunc) {
c.t = t
}
func (c *Context) SetRequestId(s string) {
c.requestId = s
}
func (c *Context) SetIPAddress(s string) {
c.ipAddress = s
}
func (c *Context) SetXForwardedFor(s string) {
c.xForwardedFor = s
}
func (c *Context) SetUserAgent(s string) {
c.userAgent = s
}
func (c *Context) SetAcceptLanguage(s string) {
c.acceptLanguage = s
}
func (c *Context) SetPath(s string) {
c.path = s
}
func (c *Context) SetContext(ctx context.Context) {
c.context = ctx
}
func (c *Context) GetT() i18n.TranslateFunc {
return c.t
}
func (c *Context) SetLogger(logger mlog.LoggerIFace) {
c.logger = logger
}
func (c *Context) Logger() mlog.LoggerIFace {
return c.logger
func (c *Context) With(f func(ctx CTX) CTX) CTX {
return f(c)
}
type CTX interface {
Clone() CTX
T(string, ...interface{}) string
GetT() i18n.TranslateFunc
Session() *model.Session
RequestId() string
IPAddress() string
@ -140,16 +158,17 @@ type CTX interface {
Path() string
UserAgent() string
AcceptLanguage() string
Context() context.Context
SetSession(s *model.Session)
SetT(i18n.TranslateFunc)
SetRequestId(string)
SetIPAddress(string)
SetUserAgent(string)
SetAcceptLanguage(string)
SetPath(string)
SetContext(ctx context.Context)
GetT() i18n.TranslateFunc
SetLogger(mlog.LoggerIFace)
Logger() mlog.LoggerIFace
Context() context.Context
WithT(i18n.TranslateFunc) CTX
WithSession(s *model.Session) CTX
WithRequestId(string) CTX
WithIPAddress(string) CTX
WithXForwardedFor(string) CTX
WithPath(string) CTX
WithUserAgent(string) CTX
WithAcceptLanguage(string) CTX
WithLogger(mlog.LoggerIFace) CTX
WithContext(ctx context.Context) CTX
With(func(ctx CTX) CTX) CTX
}