Fix racy test issues (#24971)

This commit is contained in:
Ben Schumacher 2023-11-06 12:26:17 +01:00 committed by GitHub
parent 366d1613b7
commit 486e836b83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 191 additions and 172 deletions

View File

@ -140,13 +140,12 @@ func setupTestHelper(dbStore store.Store, searchEngine *searchengine.Broker, ent
th := &TestHelper{ th := &TestHelper{
App: app.New(app.ServerConnector(s.Channels())), App: app.New(app.ServerConnector(s.Channels())),
Server: s, Server: s,
Context: request.EmptyContext(testLogger),
ConfigStore: configStore, ConfigStore: configStore,
IncludeCacheLayer: includeCache, IncludeCacheLayer: includeCache,
Context: request.EmptyContext(testLogger),
TestLogger: testLogger, TestLogger: testLogger,
LogBuffer: buffer, LogBuffer: buffer,
} }
th.Context.SetLogger(testLogger)
if s.Platform().SearchEngine != nil && s.Platform().SearchEngine.BleveEngine != nil && searchEngine != nil { if s.Platform().SearchEngine != nil && s.Platform().SearchEngine.BleveEngine != nil && searchEngine != nil {
searchEngine.BleveEngine = s.Platform().SearchEngine.BleveEngine searchEngine.BleveEngine = s.Platform().SearchEngine.BleveEngine

View File

@ -4,7 +4,6 @@
package api4 package api4
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
@ -1460,9 +1459,8 @@ func getChannelMember(c *Context, w http.ResponseWriter, r *http.Request) {
return return
} }
ctx := c.AppContext c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
ctx.SetContext(app.WithMaster(ctx.Context())) member, err := c.App.GetChannelMember(c.AppContext, c.Params.ChannelId, c.Params.UserId)
member, err := c.App.GetChannelMember(ctx, c.Params.ChannelId, c.Params.UserId)
if err != nil { if err != nil {
c.Err = err c.Err = err
return return
@ -2004,7 +2002,7 @@ func channelMemberCountsByGroup(c *Context, w http.ResponseWriter, r *http.Reque
includeTimezones := r.URL.Query().Get("include_timezones") == "true" includeTimezones := r.URL.Query().Get("include_timezones") == "true"
channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(app.WithMaster(context.Background()), c.Params.ChannelId, includeTimezones) channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(c.AppContext.With(app.RequestContextWithMaster), c.Params.ChannelId, includeTimezones)
if appErr != nil { if appErr != nil {
c.Err = appErr c.Err = appErr
return return

View File

@ -195,7 +195,7 @@ func uploadRemoteData(c *Context, w http.ResponseWriter, r *http.Request) {
defer c.LogAuditRec(auditRec) defer c.LogAuditRec(auditRec)
audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId) audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId)
c.AppContext.SetContext(app.WithMaster(c.AppContext.Context())) c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId) us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId)
if err != nil { if err != nil {
c.Err = err c.Err = err

View File

@ -123,7 +123,7 @@ func uploadData(c *Context, w http.ResponseWriter, r *http.Request) {
defer c.LogAuditRec(auditRec) defer c.LogAuditRec(auditRec)
audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId) audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId)
c.AppContext.SetContext(app.WithMaster(c.AppContext.Context())) c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId) us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId)
if err != nil { if err != nil {
c.Err = err c.Err = err

View File

@ -1953,11 +1953,12 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) {
c.LogAuditWithUserId(user.Id, "authenticated") c.LogAuditWithUserId(user.Id, "authenticated")
isMobileDevice := utils.IsMobileRequest(r) isMobileDevice := utils.IsMobileRequest(r)
err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false) session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false)
if err != nil { if err != nil {
c.Err = err c.Err = err
return return
} }
c.AppContext = c.AppContext.WithSession(session)
c.LogAuditWithUserId(user.Id, "success") c.LogAuditWithUserId(user.Id, "success")
@ -1995,11 +1996,12 @@ func loginWithDesktopToken(c *Context, w http.ResponseWriter, r *http.Request) {
return return
} }
err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false) session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false)
if err != nil { if err != nil {
c.Err = err c.Err = err
return return
} }
c.AppContext = c.AppContext.WithSession(session)
c.App.AttachSessionCookies(c.AppContext, w, r) c.App.AttachSessionCookies(c.AppContext, w, r)
@ -2048,12 +2050,13 @@ func loginCWS(c *Context, w http.ResponseWriter, r *http.Request) {
audit.AddEventParameterAuditable(auditRec, "user", user) audit.AddEventParameterAuditable(auditRec, "user", user)
c.LogAuditWithUserId(user.Id, "authenticated") c.LogAuditWithUserId(user.Id, "authenticated")
isMobileDevice := utils.IsMobileRequest(r) isMobileDevice := utils.IsMobileRequest(r)
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false) session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false)
if err != nil { if err != nil {
c.LogErrorByCode(err) c.LogErrorByCode(err)
http.Redirect(w, r, *c.App.Config().ServiceSettings.SiteURL, http.StatusFound) http.Redirect(w, r, *c.App.Config().ServiceSettings.SiteURL, http.StatusFound)
return return
} }
c.AppContext = c.AppContext.WithSession(session)
c.LogAuditWithUserId(user.Id, "success") c.LogAuditWithUserId(user.Id, "success")
c.App.AttachSessionCookies(c.AppContext, w, r) c.App.AttachSessionCookies(c.AppContext, w, r)

View File

@ -553,7 +553,7 @@ type AppIface interface {
DoEmojisPermissionsMigration() DoEmojisPermissionsMigration()
DoGuestRolesCreationMigration() DoGuestRolesCreationMigration()
DoLocalRequest(c request.CTX, rawURL string, body []byte) (*http.Response, *model.AppError) DoLocalRequest(c request.CTX, rawURL string, body []byte) (*http.Response, *model.AppError)
DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError)
DoPostAction(c request.CTX, postID, actionId, userID, selectedOption string) (string, *model.AppError) DoPostAction(c request.CTX, postID, actionId, userID, selectedOption string) (string, *model.AppError)
DoPostActionWithCookie(c request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie) (string, *model.AppError) DoPostActionWithCookie(c request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie) (string, *model.AppError)
DoSystemConsoleRolesCreationMigration() DoSystemConsoleRolesCreationMigration()
@ -687,7 +687,7 @@ type AppIface interface {
GetLatestVersion(latestVersionUrl string) (*model.GithubReleaseInfo, *model.AppError) GetLatestVersion(latestVersionUrl string) (*model.GithubReleaseInfo, *model.AppError)
GetLogs(c request.CTX, page, perPage int) ([]string, *model.AppError) GetLogs(c request.CTX, page, perPage int) ([]string, *model.AppError)
GetLogsSkipSend(page, perPage int, logFilter *model.LogFilter) ([]string, *model.AppError) GetLogsSkipSend(page, perPage int, logFilter *model.LogFilter) ([]string, *model.AppError)
GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
GetMessageForNotification(post *model.Post, translateFunc i18n.TranslateFunc) string GetMessageForNotification(post *model.Post, translateFunc i18n.TranslateFunc) string
GetMultipleEmojiByName(c request.CTX, names []string) ([]*model.Emoji, *model.AppError) GetMultipleEmojiByName(c request.CTX, names []string) ([]*model.Emoji, *model.AppError)
GetNewUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError) GetNewUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError)

View File

@ -1666,7 +1666,9 @@ func (a *App) AddChannelMember(c request.CTX, userID string, channel *model.Chan
} }
} else { } else {
a.Srv().Go(func() { a.Srv().Go(func() {
a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID) if err := a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID); err != nil {
c.Logger().Error("Failed to post AddToChannel message", mlog.Err(err))
}
}) })
} }
@ -2363,7 +2365,9 @@ func (a *App) LeaveChannel(c request.CTX, channelID string, userID string) *mode
} }
a.Srv().Go(func() { a.Srv().Go(func() {
a.postLeaveChannelMessage(c, user, channel) if err := a.postLeaveChannelMessage(c, user, channel); err != nil {
c.Logger().Error("Failed to post LeaveChannel message", mlog.Err(err))
}
}) })
return nil return nil
@ -3458,8 +3462,8 @@ func (a *App) ClearChannelMembersCache(c request.CTX, channelID string) error {
return nil return nil
} }
func (a *App) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) { func (a *App) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(ctx, channelID, includeTimezones) channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(rctx.Context(), channelID, includeTimezones)
if err != nil { if err != nil {
return nil, model.NewAppError("GetMemberCountsByGroup", "app.channel.get_member_count.app_error", nil, "", http.StatusInternalServerError).Wrap(err) return nil, model.NewAppError("GetMemberCountsByGroup", "app.channel.get_member_count.app_error", nil, "", http.StatusInternalServerError).Wrap(err)
} }

View File

@ -2152,7 +2152,7 @@ func TestGetMemberCountsByGroup(t *testing.T) {
mockChannelStore.On("GetMemberCountsByGroup", context.Background(), "channelID", true).Return(cmc, nil) mockChannelStore.On("GetMemberCountsByGroup", context.Background(), "channelID", true).Return(cmc, nil)
mockStore.On("Channel").Return(&mockChannelStore) mockStore.On("Channel").Return(&mockChannelStore)
mockStore.On("GetDBSchemaVersion").Return(1, nil) mockStore.On("GetDBSchemaVersion").Return(1, nil)
resp, err := th.App.GetMemberCountsByGroup(context.Background(), "channelID", true) resp, err := th.App.GetMemberCountsByGroup(th.Context, "channelID", true)
require.Nil(t, err) require.Nil(t, err)
require.ElementsMatch(t, cmc, resp) require.ElementsMatch(t, cmc, resp)
} }

View File

@ -34,7 +34,7 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo
job.Type = model.ComplianceTypeAdhoc job.Type = model.ComplianceTypeAdhoc
rctx.SetLogger(rctx.Logger().With(job.LoggerFields()...)) rctx = rctx.WithLogger(rctx.Logger().With(job.LoggerFields()...))
job, err := a.Srv().Store().Compliance().Save(job) job, err := a.Srv().Store().Compliance().Save(job)
if err != nil { if err != nil {
@ -48,11 +48,10 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo
} }
jCopy := job.DeepCopy() jCopy := job.DeepCopy()
crctx := rctx.Clone()
a.Srv().Go(func() { a.Srv().Go(func() {
err := a.Compliance().RunComplianceJob(crctx, jCopy) err := a.Compliance().RunComplianceJob(rctx, jCopy)
if err != nil { if err != nil {
crctx.Logger().Warn("Error running compliance job", mlog.Err(err)) rctx.Logger().Warn("Error running compliance job", mlog.Err(err))
} }
}) })

View File

@ -4,16 +4,14 @@
package app package app
import ( import (
"context"
"github.com/mattermost/mattermost/server/public/plugin" "github.com/mattermost/mattermost/server/public/plugin"
"github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
) )
// WithMaster adds the context value that master DB should be selected for this request. // RequestContextWithMaster adds the context value that master DB should be selected for this request.
func WithMaster(ctx context.Context) context.Context { func RequestContextWithMaster(c request.CTX) request.CTX {
return sqlstore.WithMaster(ctx) return sqlstore.RequestContextWithMaster(c)
} }
func pluginContext(c request.CTX) *plugin.Context { func pluginContext(c request.CTX) *plugin.Context {

View File

@ -796,11 +796,10 @@ func (a *App) UploadFileX(c request.CTX, channelID, name string, input io.Reader
if *a.Config().FileSettings.ExtractContent { if *a.Config().FileSettings.ExtractContent {
infoCopy := *t.fileinfo infoCopy := *t.fileinfo
crctx := c.Clone()
a.Srv().GoBuffered(func() { a.Srv().GoBuffered(func() {
err := a.ExtractContentFromFileInfo(crctx, &infoCopy) err := a.ExtractContentFromFileInfo(c, &infoCopy)
if err != nil { if err != nil {
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
} }
}) })
} }
@ -1048,11 +1047,10 @@ func (a *App) DoUploadFileExpectModification(c request.CTX, now time.Time, rawTe
if *a.Config().FileSettings.ExtractContent { if *a.Config().FileSettings.ExtractContent {
infoCopy := *info infoCopy := *info
crctx := c.Clone()
a.Srv().GoBuffered(func() { a.Srv().GoBuffered(func() {
err := a.ExtractContentFromFileInfo(crctx, &infoCopy) err := a.ExtractContentFromFileInfo(c, &infoCopy)
if err != nil { if err != nil {
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
} }
}) })
} }

View File

@ -99,7 +99,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
IncludeCacheLayer: includeCacheLayer, IncludeCacheLayer: includeCacheLayer,
ConfigStore: configStore, ConfigStore: configStore,
} }
th.Context.SetLogger(testLogger)
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.MaxUsersPerTeam = 50 }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.MaxUsersPerTeam = 50 })
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.RateLimitSettings.Enable = false }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.RateLimitSettings.Enable = false })

View File

@ -268,7 +268,7 @@ func (a *App) bulkImport(c request.CTX, jsonlReader io.Reader, attachmentsReader
linesChan = make(chan imports.LineImportWorkerData, workers) linesChan = make(chan imports.LineImportWorkerData, workers)
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
wg.Add(1) wg.Add(1)
go a.bulkImportWorker(c.Clone(), dryRun, &wg, linesChan, errorsChan) go a.bulkImportWorker(c, dryRun, &wg, linesChan, errorsChan)
} }
} }

View File

@ -156,7 +156,7 @@ func (a *App) GetUserForLogin(c request.CTX, id, loginId string) (*model.User, *
return nil, model.NewAppError("GetUserForLogin", "store.sql_user.get_for_login.app_error", nil, "", http.StatusBadRequest) return nil, model.NewAppError("GetUserForLogin", "store.sql_user.get_for_login.app_error", nil, "", http.StatusBadRequest)
} }
func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError { func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError) {
var rejectionReason string var rejectionReason string
pluginContext := pluginContext(c) pluginContext := pluginContext(c)
a.ch.RunMultiHook(func(hooks plugin.Hooks) bool { a.ch.RunMultiHook(func(hooks plugin.Hooks) bool {
@ -165,7 +165,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
}, plugin.UserWillLogInID) }, plugin.UserWillLogInID)
if rejectionReason != "" { if rejectionReason != "" {
return model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest) return nil, model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest)
} }
session := &model.Session{UserId: user.Id, Roles: user.GetRawRoles(), DeviceId: deviceID, IsOAuth: false, Props: map[string]string{ session := &model.Session{UserId: user.Id, Roles: user.GetRawRoles(), DeviceId: deviceID, IsOAuth: false, Props: map[string]string{
@ -181,7 +181,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
// A special case where we logout of all other sessions with the same Id // A special case where we logout of all other sessions with the same Id
if err := a.RevokeSessionsForDeviceId(c, user.Id, deviceID, ""); err != nil { if err := a.RevokeSessionsForDeviceId(c, user.Id, deviceID, ""); err != nil {
err.StatusCode = http.StatusInternalServerError err.StatusCode = http.StatusInternalServerError
return err return nil, err
} }
} else if isMobile { } else if isMobile {
a.ch.srv.platform.SetSessionExpireInHours(session, *a.Config().ServiceSettings.SessionLengthMobileInHours) a.ch.srv.platform.SetSessionExpireInHours(session, *a.Config().ServiceSettings.SessionLengthMobileInHours)
@ -210,12 +210,12 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
var err *model.AppError var err *model.AppError
if session, err = a.CreateSession(c, session); err != nil { if session, err = a.CreateSession(c, session); err != nil {
err.StatusCode = http.StatusInternalServerError err.StatusCode = http.StatusInternalServerError
return err return nil, err
} }
w.Header().Set(model.HeaderToken, session.Token) w.Header().Set(model.HeaderToken, session.Token)
c.SetSession(session) c = c.WithSession(session)
if a.Srv().License() != nil && *a.Srv().License().Features.LDAP && a.Ldap() != nil { if a.Srv().License() != nil && *a.Srv().License().Features.LDAP && a.Ldap() != nil {
userVal := *user userVal := *user
sessionVal := *session sessionVal := *session
@ -231,7 +231,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
}, plugin.UserHasLoggedInID) }, plugin.UserHasLoggedInID)
}) })
return nil return session, nil
} }
func (a *App) AttachCloudSessionCookie(c request.CTX, w http.ResponseWriter, r *http.Request) { func (a *App) AttachCloudSessionCookie(c request.CTX, w http.ResponseWriter, r *http.Request) {

View File

@ -3828,7 +3828,7 @@ func (a *OpenTracingAppLayer) DoLocalRequest(c request.CTX, rawURL string, body
return resultVar0, resultVar1 return resultVar0, resultVar1
} }
func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) *model.AppError { func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) (*model.Session, *model.AppError) {
origCtx := a.ctx origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DoLogin") span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DoLogin")
@ -3840,14 +3840,14 @@ func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *h
}() }()
defer span.Finish() defer span.Finish()
resultVar0 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml) resultVar0, resultVar1 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml)
if resultVar0 != nil { if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar0)) span.LogFields(spanlog.Error(resultVar1))
ext.Error.Set(span, true) ext.Error.Set(span, true)
} }
return resultVar0 return resultVar0, resultVar1
} }
func (a *OpenTracingAppLayer) DoPermissionsMigrations() error { func (a *OpenTracingAppLayer) DoPermissionsMigrations() error {
@ -7320,7 +7320,7 @@ func (a *OpenTracingAppLayer) GetMarketplacePlugins(filter *model.MarketplacePlu
return resultVar0, resultVar1 return resultVar0, resultVar1
} }
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) { func (a *OpenTracingAppLayer) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
origCtx := a.ctx origCtx := a.ctx
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetMemberCountsByGroup") span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetMemberCountsByGroup")
@ -7332,7 +7332,7 @@ func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channe
}() }()
defer span.Finish() defer span.Finish()
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(ctx, channelID, includeTimezones) resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(rctx, channelID, includeTimezones)
if resultVar1 != nil { if resultVar1 != nil {
span.LogFields(spanlog.Error(resultVar1)) span.LogFields(spanlog.Error(resultVar1))

View File

@ -1263,7 +1263,7 @@ func (api *PluginAPI) UploadData(us *model.UploadSession, rd io.Reader) (*model.
func (api *PluginAPI) GetUploadSession(uploadID string) (*model.UploadSession, error) { func (api *PluginAPI) GetUploadSession(uploadID string) (*model.UploadSession, error) {
// We want to fetch from master DB to avoid a potential read-after-write on the plugin side. // We want to fetch from master DB to avoid a potential read-after-write on the plugin side.
api.ctx.SetContext(WithMaster(api.ctx.Context())) api.ctx = api.ctx.With(RequestContextWithMaster)
fi, err := api.app.GetUploadSession(api.ctx, uploadID) fi, err := api.app.GetUploadSession(api.ctx, uploadID)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -727,9 +727,10 @@ func TestUserWillLogIn_Blocked(t *testing.T) {
r := &http.Request{} r := &http.Request{}
w := httptest.NewRecorder() w := httptest.NewRecorder()
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
assert.Contains(t, err.Id, "Login rejected by plugin", "Expected Login rejected by plugin, got %s", err.Id) assert.Contains(t, err.Id, "Login rejected by plugin", "Expected Login rejected by plugin, got %s", err.Id)
assert.Nil(t, session)
} }
func TestUserWillLogInIn_Passed(t *testing.T) { func TestUserWillLogInIn_Passed(t *testing.T) {
@ -766,10 +767,11 @@ func TestUserWillLogInIn_Passed(t *testing.T) {
r := &http.Request{} r := &http.Request{}
w := httptest.NewRecorder() w := httptest.NewRecorder()
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
assert.Nil(t, err, "Expected nil, got %s", err) assert.Nil(t, err, "Expected nil, got %s", err)
assert.Equal(t, th.Context.Session().UserId, th.BasicUser.Id) require.NotNil(t, session)
assert.Equal(t, session.UserId, th.BasicUser.Id)
} }
func TestUserHasLoggedIn(t *testing.T) { func TestUserHasLoggedIn(t *testing.T) {
@ -807,9 +809,10 @@ func TestUserHasLoggedIn(t *testing.T) {
r := &http.Request{} r := &http.Request{}
w := httptest.NewRecorder() w := httptest.NewRecorder()
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
assert.Nil(t, err, "Expected nil, got %s", err) assert.Nil(t, err, "Expected nil, got %s", err)
assert.NotNil(t, session)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)

View File

@ -532,7 +532,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User
a.Srv().Go(func() { a.Srv().Go(func() {
_, err := a.SendAutoResponseIfNecessary(c, channel, user, post) _, err := a.SendAutoResponseIfNecessary(c, channel, user, post)
if err != nil { if err != nil {
mlog.Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err)) c.Logger().Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err))
} }
}) })
} }
@ -540,7 +540,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User
if triggerWebhooks { if triggerWebhooks {
a.Srv().Go(func() { a.Srv().Go(func() {
if err := a.handleWebhookEvents(c, post, team, channel, user); err != nil { if err := a.handleWebhookEvents(c, post, team, channel, user); err != nil {
mlog.Error(err.Error()) c.Logger().Error("Failed to handle webhook event", mlog.Err(err))
} }
}) })
} }
@ -1378,7 +1378,7 @@ func (a *App) DeletePost(c request.CTX, postID, deleteByID string) (*model.Post,
a.Srv().Go(func() { a.Srv().Go(func() {
if err = a.RemoveNotifications(c, post, channel); err != nil { if err = a.RemoveNotifications(c, post, channel); err != nil {
a.Log().Error("DeletePost failed to delete notification", mlog.Err(err)) c.Logger().Error("DeletePost failed to delete notification", mlog.Err(err))
} }
}) })

View File

@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/app/email" "github.com/mattermost/mattermost/server/v8/channels/app/email"
emailmocks "github.com/mattermost/mattermost/server/v8/channels/app/email/mocks" emailmocks "github.com/mattermost/mattermost/server/v8/channels/app/email/mocks"
"github.com/mattermost/mattermost/server/v8/channels/app/teams" "github.com/mattermost/mattermost/server/v8/channels/app/teams"
@ -1027,7 +1028,12 @@ func TestLeaveTeamPanic(t *testing.T) {
mockLicenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil) mockLicenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil)
mockTeamStore := mocks.TeamStore{} mockTeamStore := mocks.TeamStore{}
mockTeamStore.On("GetMember", sqlstore.RequestContextWithMaster(th.Context), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil) mockTeamStore.On("GetMember", mock.AnythingOfType("*request.Context"), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil).Run(func(args mock.Arguments) {
c, ok := args[0].(request.CTX)
require.True(t, ok)
sqlstore.HasMaster(c.Context())
})
mockTeamStore.On("UpdateMember", mock.Anything).Return(nil, errors.New("repro error")) // This is the line that triggers the error mockTeamStore.On("UpdateMember", mock.Anything).Return(nil, errors.New("repro error")) // This is the line that triggers the error
mockStore.On("Channel").Return(&mockChannelStore) mockStore.On("Channel").Return(&mockChannelStore)

View File

@ -203,7 +203,7 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (
}() }()
// fetch the session from store to check for inconsistencies. // fetch the session from store to check for inconsistencies.
c.SetContext(WithMaster(c.Context())) c = c.With(RequestContextWithMaster)
if storedSession, err := a.GetUploadSession(c, us.Id); err != nil { if storedSession, err := a.GetUploadSession(c, us.Id); err != nil {
return nil, err return nil, err
} else if us.FileOffset != storedSession.FileOffset { } else if us.FileOffset != storedSession.FileOffset {
@ -318,11 +318,10 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (
if *a.Config().FileSettings.ExtractContent { if *a.Config().FileSettings.ExtractContent {
infoCopy := *info infoCopy := *info
crctx := c.Clone()
a.Srv().Go(func() { a.Srv().Go(func() {
err := a.ExtractContentFromFileInfo(crctx, &infoCopy) err := a.ExtractContentFromFileInfo(c, &infoCopy)
if err != nil { if err != nil {
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
} }
}) })
} }

View File

@ -5,6 +5,7 @@ package jobs
import ( import (
"net/http" "net/http"
"sync/atomic"
"time" "time"
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
@ -33,8 +34,9 @@ type BatchMigrationWorker struct {
store store.Store store store.Store
app BatchMigrationWorkerAppIFace app BatchMigrationWorkerAppIFace
stop chan bool stop chan struct{}
stopped chan bool stopped chan bool
closed atomic.Bool
jobs chan model.Job jobs chan model.Job
migrationKey string migrationKey string
@ -49,7 +51,7 @@ func MakeBatchMigrationWorker(jobServer *JobServer, store store.Store, app Batch
logger: jobServer.Logger().With(mlog.String("worker_name", migrationKey)), logger: jobServer.Logger().With(mlog.String("worker_name", migrationKey)),
store: store, store: store,
app: app, app: app,
stop: make(chan bool, 1), stop: make(chan struct{}),
stopped: make(chan bool, 1), stopped: make(chan bool, 1),
jobs: make(chan model.Job), jobs: make(chan model.Job),
migrationKey: migrationKey, migrationKey: migrationKey,
@ -64,7 +66,9 @@ func (worker *BatchMigrationWorker) Run() {
worker.logger.Debug("Worker started") worker.logger.Debug("Worker started")
// We have to re-assign the stop channel again, because // We have to re-assign the stop channel again, because
// it might happen that the job was restarted due to a config change. // it might happen that the job was restarted due to a config change.
worker.stop = make(chan bool, 1) if worker.closed.CompareAndSwap(true, false) {
worker.stop = make(chan struct{})
}
defer func() { defer func() {
worker.logger.Debug("Worker finished") worker.logger.Debug("Worker finished")
@ -84,6 +88,11 @@ func (worker *BatchMigrationWorker) Run() {
// Stop interrupts the worker even if the migration has not yet completed. // Stop interrupts the worker even if the migration has not yet completed.
func (worker *BatchMigrationWorker) Stop() { func (worker *BatchMigrationWorker) Stop() {
// Set to close, and if already closed before, then return.
if !worker.closed.CompareAndSwap(false, true) {
return
}
worker.logger.Debug("Worker stopping") worker.logger.Debug("Worker stopping")
close(worker.stop) close(worker.stop)
<-worker.stopped <-worker.stopped

View File

@ -88,7 +88,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
IncludeCacheLayer: includeCacheLayer, IncludeCacheLayer: includeCacheLayer,
ConfigStore: configStore, ConfigStore: configStore,
} }
th.Context.SetLogger(testLogger)
prevListenAddress := *th.App.Config().ServiceSettings.ListenAddress prevListenAddress := *th.App.Config().ServiceSettings.ListenAddress
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = "localhost:0" }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = "localhost:0" })

View File

@ -100,10 +100,10 @@ func (worker *Worker) DoJob(job *model.Job) {
return return
} }
cancelContext := request.EmptyContext(worker.logger) var cancelContext request.CTX = request.EmptyContext(worker.logger)
cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background()) cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background())
cancelWatcherChan := make(chan struct{}, 1) cancelWatcherChan := make(chan struct{}, 1)
cancelContext.SetContext(cancelCtx) cancelContext = cancelContext.WithContext(cancelCtx)
go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan) go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan)
defer cancelCancelWatcher() defer cancelCancelWatcher()

View File

@ -29,7 +29,7 @@ type S3PathMigrationWorker struct {
store store.Store store store.Store
fileBackend *filestore.S3FileBackend fileBackend *filestore.S3FileBackend
stop chan bool stop chan struct{}
stopped chan bool stopped chan bool
jobs chan model.Job jobs chan model.Job
} }
@ -45,7 +45,7 @@ func MakeWorker(jobServer *jobs.JobServer, store store.Store, fileBackend filest
logger: jobServer.Logger().With(mlog.String("worker_name", workerName)), logger: jobServer.Logger().With(mlog.String("worker_name", workerName)),
store: store, store: store,
fileBackend: s3Backend, fileBackend: s3Backend,
stop: make(chan bool, 1), stop: make(chan struct{}),
stopped: make(chan bool, 1), stopped: make(chan bool, 1),
jobs: make(chan model.Job), jobs: make(chan model.Job),
} }
@ -56,7 +56,7 @@ func (worker *S3PathMigrationWorker) Run() {
worker.logger.Debug("Worker started") worker.logger.Debug("Worker started")
// We have to re-assign the stop channel again, because // We have to re-assign the stop channel again, because
// it might happen that the job was restarted due to a config change. // it might happen that the job was restarted due to a config change.
worker.stop = make(chan bool, 1) worker.stop = make(chan struct{}, 1)
defer func() { defer func() {
worker.logger.Debug("Worker finished") worker.logger.Debug("Worker finished")

View File

@ -31,7 +31,7 @@ func WithMaster(ctx context.Context) context.Context {
// RequestContextWithMaster adds the context value that master DB should be selected for this request. // RequestContextWithMaster adds the context value that master DB should be selected for this request.
func RequestContextWithMaster(c request.CTX) request.CTX { func RequestContextWithMaster(c request.CTX) request.CTX {
ctx := WithMaster(c.Context()) ctx := WithMaster(c.Context())
c.SetContext(ctx) c = c.WithContext(ctx)
return c return c
} }

View File

@ -26,39 +26,30 @@ func TestRequestContextWithMaster(t *testing.T) {
assert.True(t, HasMaster(rctx.Context())) assert.True(t, HasMaster(rctx.Context()))
}) })
t.Run("directly assigning does cause the child to alter the parent", func(t *testing.T) { t.Run("values get copied from original context", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t)
rctxClone := rctx
rctxClone = RequestContextWithMaster(rctxClone)
assert.True(t, HasMaster(rctx.Context()))
assert.True(t, HasMaster(rctxClone.Context()))
})
t.Run("values get copied from parent", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t) var rctx request.CTX = request.TestContext(t)
rctx = RequestContextWithMaster(rctx) rctx = RequestContextWithMaster(rctx)
rctxClone := rctx.Clone() rctxCopy := rctx
assert.True(t, HasMaster(rctx.Context())) assert.True(t, HasMaster(rctx.Context()))
assert.True(t, HasMaster(rctxClone.Context())) assert.True(t, HasMaster(rctxCopy.Context()))
}) })
t.Run("changing the child does not alter the parent", func(t *testing.T) { t.Run("directly assigning does not cause the copy to alter the original context", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t) var rctx request.CTX = request.TestContext(t)
rctxClone := rctx.Clone() rctxCopy := rctx
rctxClone = RequestContextWithMaster(rctxClone) rctxCopy = RequestContextWithMaster(rctxCopy)
assert.False(t, HasMaster(rctx.Context())) assert.False(t, HasMaster(rctx.Context()))
assert.True(t, HasMaster(rctxClone.Context())) assert.True(t, HasMaster(rctxCopy.Context()))
}) })
t.Run("changing the parent does not alter the child", func(t *testing.T) { t.Run("directly assigning does not cause the original context to alter the copy", func(t *testing.T) {
var rctx request.CTX = request.TestContext(t) var rctx request.CTX = request.TestContext(t)
rctxClone := rctx.Clone() rctxCopy := rctx
rctx = RequestContextWithMaster(rctx) rctx = RequestContextWithMaster(rctx)
assert.True(t, HasMaster(rctx.Context())) assert.True(t, HasMaster(rctx.Context()))
assert.False(t, HasMaster(rctxClone.Context())) assert.False(t, HasMaster(rctxCopy.Context()))
}) })
} }

View File

@ -20,7 +20,7 @@ import (
type Context struct { type Context struct {
App app.AppIface App app.AppIface
AppContext *request.Context AppContext request.CTX
Logger *mlog.Logger Logger *mlog.Logger
Params *Params Params *Params
Err *model.AppError Err *model.AppError

View File

@ -71,7 +71,7 @@ func TestMfaRequired(t *testing.T) {
th.App.Srv().SetLicense(model.NewTestLicense("mfa")) th.App.Srv().SetLicense(model.NewTestLicense("mfa"))
th.Context.SetSession(&model.Session{Id: "abc", UserId: "userid"}) th.Context = th.Context.WithSession(&model.Session{Id: "abc", UserId: "userid"})
th.App.UpdateConfig(func(cfg *model.Config) { th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.AnnouncementSettings.UserNoticesEnabled = false *cfg.AnnouncementSettings.UserNoticesEnabled = false

View File

@ -205,7 +205,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
span.Finish() span.Finish()
}() }()
c.AppContext.SetContext(ctx) c.AppContext = c.AppContext.WithContext(ctx)
tmpSrv := *c.App.Srv() tmpSrv := *c.App.Srv()
tmpSrv.SetStore(opentracinglayer.New(c.App.Srv().Store(), ctx)) tmpSrv.SetStore(opentracinglayer.New(c.App.Srv().Store(), ctx))
@ -285,7 +285,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString { } else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString {
c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized) c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized)
} else { } else {
c.AppContext.SetSession(session) c.AppContext = c.AppContext.WithSession(session)
} }
// Rate limit by UserID // Rate limit by UserID
@ -301,7 +301,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.Logger.Warn("Invalid CWS token", mlog.Err(err)) c.Logger.Warn("Invalid CWS token", mlog.Err(err))
c.Err = err c.Err = err
} else { } else {
c.AppContext.SetSession(session) c.AppContext = c.AppContext.WithSession(session)
} }
} else if token != "" && c.App.Channels().License() != nil && c.App.Channels().License().HasRemoteClusterService() && tokenLocation == app.TokenLocationRemoteClusterHeader { } else if token != "" && c.App.Channels().License() != nil && c.App.Channels().License().HasRemoteClusterService() && tokenLocation == app.TokenLocationRemoteClusterHeader {
// Get the remote cluster // Get the remote cluster
@ -315,7 +315,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.Logger.Warn("Invalid remote cluster token", mlog.Err(err)) c.Logger.Warn("Invalid remote cluster token", mlog.Err(err))
c.Err = err c.Err = err
} else { } else {
c.AppContext.SetSession(session) c.AppContext = c.AppContext.WithSession(session)
} }
} }
} }
@ -327,7 +327,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
mlog.String("user_id", c.AppContext.Session().UserId), mlog.String("user_id", c.AppContext.Session().UserId),
mlog.String("method", r.Method), mlog.String("method", r.Method),
) )
c.AppContext.SetLogger(c.Logger) c.AppContext = c.AppContext.WithLogger(c.Logger)
if c.Err == nil && h.RequireSession { if c.Err == nil && h.RequireSession {
c.SessionRequired() c.SessionRequired()
@ -354,7 +354,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// shape IP:PORT (it will be "@" in Linux, for example) // shape IP:PORT (it will be "@" in Linux, for example)
isLocalOrigin := !strings.Contains(r.RemoteAddr, ":") isLocalOrigin := !strings.Contains(r.RemoteAddr, ":")
if *c.App.Config().ServiceSettings.EnableLocalMode && isLocalOrigin { if *c.App.Config().ServiceSettings.EnableLocalMode && isLocalOrigin {
c.AppContext.SetSession(&model.Session{Local: true}) c.AppContext = c.AppContext.WithSession(&model.Session{Local: true})
} else if !isLocalOrigin { } else if !isLocalOrigin {
c.Err = model.NewAppError("", "api.context.local_origin_required.app_error", nil, "LocalOriginRequired", http.StatusUnauthorized) c.Err = model.NewAppError("", "api.context.local_origin_required.app_error", nil, "LocalOriginRequired", http.StatusUnauthorized)
} }
@ -501,7 +501,7 @@ func (h *Handler) checkCSRFToken(c *Context, r *http.Request, token string, toke
} }
if !csrfCheckPassed { if !csrfCheckPassed {
c.AppContext.SetSession(&model.Session{}) c.AppContext = c.AppContext.WithSession(&model.Session{})
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized) c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
} }
} }

View File

@ -321,13 +321,14 @@ func completeOAuth(c *Context, w http.ResponseWriter, r *http.Request) {
} else if action == model.OAuthActionSSOToEmail { } else if action == model.OAuthActionSSOToEmail {
redirectURL = app.GetProtocol(r) + "://" + r.Host + "/claim?email=" + url.QueryEscape(props["email"]) redirectURL = app.GetProtocol(r) + "://" + r.Host + "/claim?email=" + url.QueryEscape(props["email"])
} else { } else {
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false) session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false)
if err != nil { if err != nil {
err.Translate(c.AppContext.T) err.Translate(c.AppContext.T)
mlog.Error(err.Error()) mlog.Error(err.Error())
renderError(err) renderError(err)
return return
} }
c.AppContext = c.AppContext.WithSession(session)
// Old mobile version // Old mobile version
if isMobile && !hasRedirectURL { if isMobile && !hasRedirectURL {

View File

@ -18,8 +18,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/i18n"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/utils" "github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/einterfaces" "github.com/mattermost/mattermost/server/v8/einterfaces"
@ -397,20 +395,18 @@ func TestMobileLoginWithOAuth(t *testing.T) {
c := &Context{ c := &Context{
App: th.App, App: th.App,
AppContext: th.Context, AppContext: th.Context,
Logger: th.TestLogger,
Params: &Params{ Params: &Params{
Service: "gitlab", Service: "gitlab",
}, },
} }
var siteURL = "http://localhost:8065" siteURL := "http://localhost:8065"
th.App.UpdateConfig(func(cfg *model.Config) { th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.SiteURL = siteURL *cfg.ServiceSettings.SiteURL = siteURL
*cfg.GitLabSettings.Enable = true *cfg.GitLabSettings.Enable = true
}) })
translationFunc := i18n.GetUserTranslations("en")
c.AppContext.SetT(translationFunc)
c.Logger = th.TestLogger
provider := &MattermostTestProvider{} provider := &MattermostTestProvider{}
einterfaces.RegisterOAuthProvider(model.ServiceGitlab, provider) einterfaces.RegisterOAuthProvider(model.ServiceGitlab, provider)
@ -617,14 +613,12 @@ func TestOAuthComplete_ErrorMessages(t *testing.T) {
c := &Context{ c := &Context{
App: th.App, App: th.App,
AppContext: th.Context, AppContext: th.Context,
Logger: th.TestLogger,
Params: &Params{ Params: &Params{
Service: "gitlab", Service: "gitlab",
}, },
} }
translationFunc := i18n.GetUserTranslations("en")
c.AppContext.SetT(translationFunc)
c.Logger = mlog.CreateConsoleTestLogger(t)
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.GitLabSettings.Enable = true }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.GitLabSettings.Enable = true })
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.EnableOAuthServiceProvider = true }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.EnableOAuthServiceProvider = true })
provider := &MattermostTestProvider{} provider := &MattermostTestProvider{}

View File

@ -178,11 +178,12 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
auditRec.AddMeta("obtained_user_id", user.Id) auditRec.AddMeta("obtained_user_id", user.Id)
c.LogAuditWithUserId(user.Id, "obtained user") c.LogAuditWithUserId(user.Id, "obtained user")
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true) session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true)
if err != nil { if err != nil {
handleError(err) handleError(err)
return return
} }
c.AppContext = c.AppContext.WithSession(session)
auditRec.Success() auditRec.Success()
c.LogAuditWithUserId(user.Id, "success") c.LogAuditWithUserId(user.Id, "success")

View File

@ -33,7 +33,7 @@ var URL string
type TestHelper struct { type TestHelper struct {
App app.AppIface App app.AppIface
Context *request.Context Context request.CTX
Server *app.Server Server *app.Server
Web *Web Web *Web
@ -141,7 +141,6 @@ func setupTestHelper(tb testing.TB, includeCacheLayer bool, options []app.Option
IncludeCacheLayer: includeCacheLayer, IncludeCacheLayer: includeCacheLayer,
TestLogger: testLogger, TestLogger: testLogger,
} }
th.Context.SetLogger(testLogger)
return th return th
} }

View File

@ -138,10 +138,10 @@ func scheduleExportCmdF(command *cobra.Command, args []string) error {
defer cancel() defer cancel()
} }
c := request.EmptyContext(a.Log()) var rctx request.CTX = request.EmptyContext(a.Log())
c.SetContext(ctx) rctx = rctx.WithContext(ctx)
job, err := messageExportI.StartSynchronizeJob(c, startTime) job, err := messageExportI.StartSynchronizeJob(rctx, startTime)
if err != nil || job.Status == model.JobStatusError || job.Status == model.JobStatusCanceled { if err != nil || job.Status == model.JobStatusError || job.Status == model.JobStatusCanceled {
CommandPrintErrorln("ERROR: Message export job failed. Please check the server logs") CommandPrintErrorln("ERROR: Message export job failed. Please check the server logs")
} else { } else {

View File

@ -253,10 +253,10 @@ func (worker *BleveIndexerWorker) DoJob(job *model.Job) {
progress.TotalFilesCount = count progress.TotalFilesCount = count
} }
cancelContext := request.EmptyContext(worker.logger) var cancelContext request.CTX = request.EmptyContext(worker.logger)
cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background()) cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background())
cancelWatcherChan := make(chan struct{}, 1) cancelWatcherChan := make(chan struct{}, 1)
cancelContext.SetContext(cancelCtx) cancelContext = cancelContext.WithContext(cancelCtx)
go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan) go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan)
defer cancelCancelWatcher() defer cancelCancelWatcher()

View File

@ -22,8 +22,7 @@ type Context struct {
userAgent string userAgent string
acceptLanguage string acceptLanguage string
logger mlog.LoggerIFace logger mlog.LoggerIFace
context context.Context
context context.Context
} }
func NewContext(ctx context.Context, requestId, ipAddress, xForwardedFor, path, userAgent, acceptLanguage string, t i18n.TranslateFunc) *Context { func NewContext(ctx context.Context, requestId, ipAddress, xForwardedFor, path, userAgent, acceptLanguage string, t i18n.TranslateFunc) *Context {
@ -54,8 +53,8 @@ func TestContext(t testing.TB) *Context {
return EmptyContext(logger) return EmptyContext(logger)
} }
// Clone creates a shallow copy of Context, allowing clones to apply per-request changes. // clone creates a shallow copy of Context, allowing clones to apply per-request changes.
func (c *Context) Clone() CTX { func (c *Context) clone() *Context {
cCopy := *c cCopy := *c
return &cCopy return &cCopy
} }
@ -63,6 +62,9 @@ func (c *Context) Clone() CTX {
func (c *Context) T(translationID string, args ...any) string { func (c *Context) T(translationID string, args ...any) string {
return c.t(translationID, args...) return c.t(translationID, args...)
} }
func (c *Context) GetT() i18n.TranslateFunc {
return c.t
}
func (c *Context) Session() *model.Session { func (c *Context) Session() *model.Session {
return &c.session return &c.session
} }
@ -84,55 +86,71 @@ func (c *Context) UserAgent() string {
func (c *Context) AcceptLanguage() string { func (c *Context) AcceptLanguage() string {
return c.acceptLanguage return c.acceptLanguage
} }
func (c *Context) Logger() mlog.LoggerIFace {
return c.logger
}
func (c *Context) Context() context.Context { func (c *Context) Context() context.Context {
return c.context return c.context
} }
func (c *Context) SetSession(s *model.Session) { func (c *Context) WithT(t i18n.TranslateFunc) CTX {
c.session = *s rctx := c.clone()
rctx.t = t
return rctx
}
func (c *Context) WithSession(s *model.Session) CTX {
rctx := c.clone()
rctx.session = *s
return rctx
}
func (c *Context) WithRequestId(s string) CTX {
rctx := c.clone()
rctx.requestId = s
return rctx
}
func (c *Context) WithIPAddress(s string) CTX {
rctx := c.clone()
rctx.ipAddress = s
return rctx
}
func (c *Context) WithXForwardedFor(s string) CTX {
rctx := c.clone()
rctx.xForwardedFor = s
return rctx
}
func (c *Context) WithPath(s string) CTX {
rctx := c.clone()
rctx.path = s
return rctx
}
func (c *Context) WithUserAgent(s string) CTX {
rctx := c.clone()
rctx.userAgent = s
return rctx
}
func (c *Context) WithAcceptLanguage(s string) CTX {
rctx := c.clone()
rctx.acceptLanguage = s
return rctx
}
func (c *Context) WithContext(ctx context.Context) CTX {
rctx := c.clone()
rctx.context = ctx
return rctx
}
func (c *Context) WithLogger(logger mlog.LoggerIFace) CTX {
rctx := c.clone()
rctx.logger = logger
return rctx
} }
func (c *Context) SetT(t i18n.TranslateFunc) { func (c *Context) With(f func(ctx CTX) CTX) CTX {
c.t = t return f(c)
}
func (c *Context) SetRequestId(s string) {
c.requestId = s
}
func (c *Context) SetIPAddress(s string) {
c.ipAddress = s
}
func (c *Context) SetXForwardedFor(s string) {
c.xForwardedFor = s
}
func (c *Context) SetUserAgent(s string) {
c.userAgent = s
}
func (c *Context) SetAcceptLanguage(s string) {
c.acceptLanguage = s
}
func (c *Context) SetPath(s string) {
c.path = s
}
func (c *Context) SetContext(ctx context.Context) {
c.context = ctx
}
func (c *Context) GetT() i18n.TranslateFunc {
return c.t
}
func (c *Context) SetLogger(logger mlog.LoggerIFace) {
c.logger = logger
}
func (c *Context) Logger() mlog.LoggerIFace {
return c.logger
} }
type CTX interface { type CTX interface {
Clone() CTX
T(string, ...interface{}) string T(string, ...interface{}) string
GetT() i18n.TranslateFunc
Session() *model.Session Session() *model.Session
RequestId() string RequestId() string
IPAddress() string IPAddress() string
@ -140,16 +158,17 @@ type CTX interface {
Path() string Path() string
UserAgent() string UserAgent() string
AcceptLanguage() string AcceptLanguage() string
Context() context.Context
SetSession(s *model.Session)
SetT(i18n.TranslateFunc)
SetRequestId(string)
SetIPAddress(string)
SetUserAgent(string)
SetAcceptLanguage(string)
SetPath(string)
SetContext(ctx context.Context)
GetT() i18n.TranslateFunc
SetLogger(mlog.LoggerIFace)
Logger() mlog.LoggerIFace Logger() mlog.LoggerIFace
Context() context.Context
WithT(i18n.TranslateFunc) CTX
WithSession(s *model.Session) CTX
WithRequestId(string) CTX
WithIPAddress(string) CTX
WithXForwardedFor(string) CTX
WithPath(string) CTX
WithUserAgent(string) CTX
WithAcceptLanguage(string) CTX
WithLogger(mlog.LoggerIFace) CTX
WithContext(ctx context.Context) CTX
With(func(ctx CTX) CTX) CTX
} }