mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
Fix racy test issues (#24971)
This commit is contained in:
parent
366d1613b7
commit
486e836b83
@ -140,13 +140,12 @@ func setupTestHelper(dbStore store.Store, searchEngine *searchengine.Broker, ent
|
||||
th := &TestHelper{
|
||||
App: app.New(app.ServerConnector(s.Channels())),
|
||||
Server: s,
|
||||
Context: request.EmptyContext(testLogger),
|
||||
ConfigStore: configStore,
|
||||
IncludeCacheLayer: includeCache,
|
||||
Context: request.EmptyContext(testLogger),
|
||||
TestLogger: testLogger,
|
||||
LogBuffer: buffer,
|
||||
}
|
||||
th.Context.SetLogger(testLogger)
|
||||
|
||||
if s.Platform().SearchEngine != nil && s.Platform().SearchEngine.BleveEngine != nil && searchEngine != nil {
|
||||
searchEngine.BleveEngine = s.Platform().SearchEngine.BleveEngine
|
||||
|
@ -4,7 +4,6 @@
|
||||
package api4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -1460,9 +1459,8 @@ func getChannelMember(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.AppContext
|
||||
ctx.SetContext(app.WithMaster(ctx.Context()))
|
||||
member, err := c.App.GetChannelMember(ctx, c.Params.ChannelId, c.Params.UserId)
|
||||
c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
|
||||
member, err := c.App.GetChannelMember(c.AppContext, c.Params.ChannelId, c.Params.UserId)
|
||||
if err != nil {
|
||||
c.Err = err
|
||||
return
|
||||
@ -2004,7 +2002,7 @@ func channelMemberCountsByGroup(c *Context, w http.ResponseWriter, r *http.Reque
|
||||
|
||||
includeTimezones := r.URL.Query().Get("include_timezones") == "true"
|
||||
|
||||
channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(app.WithMaster(context.Background()), c.Params.ChannelId, includeTimezones)
|
||||
channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(c.AppContext.With(app.RequestContextWithMaster), c.Params.ChannelId, includeTimezones)
|
||||
if appErr != nil {
|
||||
c.Err = appErr
|
||||
return
|
||||
|
@ -195,7 +195,7 @@ func uploadRemoteData(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
defer c.LogAuditRec(auditRec)
|
||||
audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId)
|
||||
|
||||
c.AppContext.SetContext(app.WithMaster(c.AppContext.Context()))
|
||||
c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
|
||||
us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId)
|
||||
if err != nil {
|
||||
c.Err = err
|
||||
|
@ -123,7 +123,7 @@ func uploadData(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
defer c.LogAuditRec(auditRec)
|
||||
audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId)
|
||||
|
||||
c.AppContext.SetContext(app.WithMaster(c.AppContext.Context()))
|
||||
c.AppContext = c.AppContext.With(app.RequestContextWithMaster)
|
||||
us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId)
|
||||
if err != nil {
|
||||
c.Err = err
|
||||
|
@ -1953,11 +1953,12 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
c.LogAuditWithUserId(user.Id, "authenticated")
|
||||
|
||||
isMobileDevice := utils.IsMobileRequest(r)
|
||||
err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false)
|
||||
session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false)
|
||||
if err != nil {
|
||||
c.Err = err
|
||||
return
|
||||
}
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
|
||||
c.LogAuditWithUserId(user.Id, "success")
|
||||
|
||||
@ -1995,11 +1996,12 @@ func loginWithDesktopToken(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false)
|
||||
session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false)
|
||||
if err != nil {
|
||||
c.Err = err
|
||||
return
|
||||
}
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
|
||||
c.App.AttachSessionCookies(c.AppContext, w, r)
|
||||
|
||||
@ -2048,12 +2050,13 @@ func loginCWS(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
audit.AddEventParameterAuditable(auditRec, "user", user)
|
||||
c.LogAuditWithUserId(user.Id, "authenticated")
|
||||
isMobileDevice := utils.IsMobileRequest(r)
|
||||
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false)
|
||||
session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false)
|
||||
if err != nil {
|
||||
c.LogErrorByCode(err)
|
||||
http.Redirect(w, r, *c.App.Config().ServiceSettings.SiteURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
c.LogAuditWithUserId(user.Id, "success")
|
||||
c.App.AttachSessionCookies(c.AppContext, w, r)
|
||||
|
||||
|
@ -553,7 +553,7 @@ type AppIface interface {
|
||||
DoEmojisPermissionsMigration()
|
||||
DoGuestRolesCreationMigration()
|
||||
DoLocalRequest(c request.CTX, rawURL string, body []byte) (*http.Response, *model.AppError)
|
||||
DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError
|
||||
DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError)
|
||||
DoPostAction(c request.CTX, postID, actionId, userID, selectedOption string) (string, *model.AppError)
|
||||
DoPostActionWithCookie(c request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie) (string, *model.AppError)
|
||||
DoSystemConsoleRolesCreationMigration()
|
||||
@ -687,7 +687,7 @@ type AppIface interface {
|
||||
GetLatestVersion(latestVersionUrl string) (*model.GithubReleaseInfo, *model.AppError)
|
||||
GetLogs(c request.CTX, page, perPage int) ([]string, *model.AppError)
|
||||
GetLogsSkipSend(page, perPage int, logFilter *model.LogFilter) ([]string, *model.AppError)
|
||||
GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
|
||||
GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError)
|
||||
GetMessageForNotification(post *model.Post, translateFunc i18n.TranslateFunc) string
|
||||
GetMultipleEmojiByName(c request.CTX, names []string) ([]*model.Emoji, *model.AppError)
|
||||
GetNewUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError)
|
||||
|
@ -1666,7 +1666,9 @@ func (a *App) AddChannelMember(c request.CTX, userID string, channel *model.Chan
|
||||
}
|
||||
} else {
|
||||
a.Srv().Go(func() {
|
||||
a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID)
|
||||
if err := a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID); err != nil {
|
||||
c.Logger().Error("Failed to post AddToChannel message", mlog.Err(err))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -2363,7 +2365,9 @@ func (a *App) LeaveChannel(c request.CTX, channelID string, userID string) *mode
|
||||
}
|
||||
|
||||
a.Srv().Go(func() {
|
||||
a.postLeaveChannelMessage(c, user, channel)
|
||||
if err := a.postLeaveChannelMessage(c, user, channel); err != nil {
|
||||
c.Logger().Error("Failed to post LeaveChannel message", mlog.Err(err))
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
@ -3458,8 +3462,8 @@ func (a *App) ClearChannelMembersCache(c request.CTX, channelID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
|
||||
channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(ctx, channelID, includeTimezones)
|
||||
func (a *App) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
|
||||
channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(rctx.Context(), channelID, includeTimezones)
|
||||
if err != nil {
|
||||
return nil, model.NewAppError("GetMemberCountsByGroup", "app.channel.get_member_count.app_error", nil, "", http.StatusInternalServerError).Wrap(err)
|
||||
}
|
||||
|
@ -2152,7 +2152,7 @@ func TestGetMemberCountsByGroup(t *testing.T) {
|
||||
mockChannelStore.On("GetMemberCountsByGroup", context.Background(), "channelID", true).Return(cmc, nil)
|
||||
mockStore.On("Channel").Return(&mockChannelStore)
|
||||
mockStore.On("GetDBSchemaVersion").Return(1, nil)
|
||||
resp, err := th.App.GetMemberCountsByGroup(context.Background(), "channelID", true)
|
||||
resp, err := th.App.GetMemberCountsByGroup(th.Context, "channelID", true)
|
||||
require.Nil(t, err)
|
||||
require.ElementsMatch(t, cmc, resp)
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo
|
||||
|
||||
job.Type = model.ComplianceTypeAdhoc
|
||||
|
||||
rctx.SetLogger(rctx.Logger().With(job.LoggerFields()...))
|
||||
rctx = rctx.WithLogger(rctx.Logger().With(job.LoggerFields()...))
|
||||
|
||||
job, err := a.Srv().Store().Compliance().Save(job)
|
||||
if err != nil {
|
||||
@ -48,11 +48,10 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo
|
||||
}
|
||||
|
||||
jCopy := job.DeepCopy()
|
||||
crctx := rctx.Clone()
|
||||
a.Srv().Go(func() {
|
||||
err := a.Compliance().RunComplianceJob(crctx, jCopy)
|
||||
err := a.Compliance().RunComplianceJob(rctx, jCopy)
|
||||
if err != nil {
|
||||
crctx.Logger().Warn("Error running compliance job", mlog.Err(err))
|
||||
rctx.Logger().Warn("Error running compliance job", mlog.Err(err))
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -4,16 +4,14 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/plugin"
|
||||
"github.com/mattermost/mattermost/server/public/shared/request"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
|
||||
)
|
||||
|
||||
// WithMaster adds the context value that master DB should be selected for this request.
|
||||
func WithMaster(ctx context.Context) context.Context {
|
||||
return sqlstore.WithMaster(ctx)
|
||||
// RequestContextWithMaster adds the context value that master DB should be selected for this request.
|
||||
func RequestContextWithMaster(c request.CTX) request.CTX {
|
||||
return sqlstore.RequestContextWithMaster(c)
|
||||
}
|
||||
|
||||
func pluginContext(c request.CTX) *plugin.Context {
|
||||
|
@ -796,11 +796,10 @@ func (a *App) UploadFileX(c request.CTX, channelID, name string, input io.Reader
|
||||
|
||||
if *a.Config().FileSettings.ExtractContent {
|
||||
infoCopy := *t.fileinfo
|
||||
crctx := c.Clone()
|
||||
a.Srv().GoBuffered(func() {
|
||||
err := a.ExtractContentFromFileInfo(crctx, &infoCopy)
|
||||
err := a.ExtractContentFromFileInfo(c, &infoCopy)
|
||||
if err != nil {
|
||||
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
|
||||
c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -1048,11 +1047,10 @@ func (a *App) DoUploadFileExpectModification(c request.CTX, now time.Time, rawTe
|
||||
|
||||
if *a.Config().FileSettings.ExtractContent {
|
||||
infoCopy := *info
|
||||
crctx := c.Clone()
|
||||
a.Srv().GoBuffered(func() {
|
||||
err := a.ExtractContentFromFileInfo(crctx, &infoCopy)
|
||||
err := a.ExtractContentFromFileInfo(c, &infoCopy)
|
||||
if err != nil {
|
||||
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
|
||||
c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -99,7 +99,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
|
||||
IncludeCacheLayer: includeCacheLayer,
|
||||
ConfigStore: configStore,
|
||||
}
|
||||
th.Context.SetLogger(testLogger)
|
||||
|
||||
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.MaxUsersPerTeam = 50 })
|
||||
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.RateLimitSettings.Enable = false })
|
||||
|
@ -268,7 +268,7 @@ func (a *App) bulkImport(c request.CTX, jsonlReader io.Reader, attachmentsReader
|
||||
linesChan = make(chan imports.LineImportWorkerData, workers)
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
go a.bulkImportWorker(c.Clone(), dryRun, &wg, linesChan, errorsChan)
|
||||
go a.bulkImportWorker(c, dryRun, &wg, linesChan, errorsChan)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,7 @@ func (a *App) GetUserForLogin(c request.CTX, id, loginId string) (*model.User, *
|
||||
return nil, model.NewAppError("GetUserForLogin", "store.sql_user.get_for_login.app_error", nil, "", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError {
|
||||
func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError) {
|
||||
var rejectionReason string
|
||||
pluginContext := pluginContext(c)
|
||||
a.ch.RunMultiHook(func(hooks plugin.Hooks) bool {
|
||||
@ -165,7 +165,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
|
||||
}, plugin.UserWillLogInID)
|
||||
|
||||
if rejectionReason != "" {
|
||||
return model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest)
|
||||
return nil, model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
session := &model.Session{UserId: user.Id, Roles: user.GetRawRoles(), DeviceId: deviceID, IsOAuth: false, Props: map[string]string{
|
||||
@ -181,7 +181,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
|
||||
// A special case where we logout of all other sessions with the same Id
|
||||
if err := a.RevokeSessionsForDeviceId(c, user.Id, deviceID, ""); err != nil {
|
||||
err.StatusCode = http.StatusInternalServerError
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
} else if isMobile {
|
||||
a.ch.srv.platform.SetSessionExpireInHours(session, *a.Config().ServiceSettings.SessionLengthMobileInHours)
|
||||
@ -210,12 +210,12 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
|
||||
var err *model.AppError
|
||||
if session, err = a.CreateSession(c, session); err != nil {
|
||||
err.StatusCode = http.StatusInternalServerError
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w.Header().Set(model.HeaderToken, session.Token)
|
||||
|
||||
c.SetSession(session)
|
||||
c = c.WithSession(session)
|
||||
if a.Srv().License() != nil && *a.Srv().License().Features.LDAP && a.Ldap() != nil {
|
||||
userVal := *user
|
||||
sessionVal := *session
|
||||
@ -231,7 +231,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use
|
||||
}, plugin.UserHasLoggedInID)
|
||||
})
|
||||
|
||||
return nil
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (a *App) AttachCloudSessionCookie(c request.CTX, w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -3828,7 +3828,7 @@ func (a *OpenTracingAppLayer) DoLocalRequest(c request.CTX, rawURL string, body
|
||||
return resultVar0, resultVar1
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) *model.AppError {
|
||||
func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) (*model.Session, *model.AppError) {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DoLogin")
|
||||
|
||||
@ -3840,14 +3840,14 @@ func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *h
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
resultVar0 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml)
|
||||
resultVar0, resultVar1 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml)
|
||||
|
||||
if resultVar0 != nil {
|
||||
span.LogFields(spanlog.Error(resultVar0))
|
||||
if resultVar1 != nil {
|
||||
span.LogFields(spanlog.Error(resultVar1))
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
|
||||
return resultVar0
|
||||
return resultVar0, resultVar1
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) DoPermissionsMigrations() error {
|
||||
@ -7320,7 +7320,7 @@ func (a *OpenTracingAppLayer) GetMarketplacePlugins(filter *model.MarketplacePlu
|
||||
return resultVar0, resultVar1
|
||||
}
|
||||
|
||||
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
|
||||
func (a *OpenTracingAppLayer) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) {
|
||||
origCtx := a.ctx
|
||||
span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetMemberCountsByGroup")
|
||||
|
||||
@ -7332,7 +7332,7 @@ func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channe
|
||||
}()
|
||||
|
||||
defer span.Finish()
|
||||
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(ctx, channelID, includeTimezones)
|
||||
resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(rctx, channelID, includeTimezones)
|
||||
|
||||
if resultVar1 != nil {
|
||||
span.LogFields(spanlog.Error(resultVar1))
|
||||
|
@ -1263,7 +1263,7 @@ func (api *PluginAPI) UploadData(us *model.UploadSession, rd io.Reader) (*model.
|
||||
|
||||
func (api *PluginAPI) GetUploadSession(uploadID string) (*model.UploadSession, error) {
|
||||
// We want to fetch from master DB to avoid a potential read-after-write on the plugin side.
|
||||
api.ctx.SetContext(WithMaster(api.ctx.Context()))
|
||||
api.ctx = api.ctx.With(RequestContextWithMaster)
|
||||
fi, err := api.app.GetUploadSession(api.ctx, uploadID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -727,9 +727,10 @@ func TestUserWillLogIn_Blocked(t *testing.T) {
|
||||
|
||||
r := &http.Request{}
|
||||
w := httptest.NewRecorder()
|
||||
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
|
||||
session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
|
||||
|
||||
assert.Contains(t, err.Id, "Login rejected by plugin", "Expected Login rejected by plugin, got %s", err.Id)
|
||||
assert.Nil(t, session)
|
||||
}
|
||||
|
||||
func TestUserWillLogInIn_Passed(t *testing.T) {
|
||||
@ -766,10 +767,11 @@ func TestUserWillLogInIn_Passed(t *testing.T) {
|
||||
|
||||
r := &http.Request{}
|
||||
w := httptest.NewRecorder()
|
||||
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
|
||||
session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
|
||||
|
||||
assert.Nil(t, err, "Expected nil, got %s", err)
|
||||
assert.Equal(t, th.Context.Session().UserId, th.BasicUser.Id)
|
||||
require.NotNil(t, session)
|
||||
assert.Equal(t, session.UserId, th.BasicUser.Id)
|
||||
}
|
||||
|
||||
func TestUserHasLoggedIn(t *testing.T) {
|
||||
@ -807,9 +809,10 @@ func TestUserHasLoggedIn(t *testing.T) {
|
||||
|
||||
r := &http.Request{}
|
||||
w := httptest.NewRecorder()
|
||||
err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
|
||||
session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false)
|
||||
|
||||
assert.Nil(t, err, "Expected nil, got %s", err)
|
||||
assert.NotNil(t, session)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
|
@ -532,7 +532,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User
|
||||
a.Srv().Go(func() {
|
||||
_, err := a.SendAutoResponseIfNecessary(c, channel, user, post)
|
||||
if err != nil {
|
||||
mlog.Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err))
|
||||
c.Logger().Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -540,7 +540,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User
|
||||
if triggerWebhooks {
|
||||
a.Srv().Go(func() {
|
||||
if err := a.handleWebhookEvents(c, post, team, channel, user); err != nil {
|
||||
mlog.Error(err.Error())
|
||||
c.Logger().Error("Failed to handle webhook event", mlog.Err(err))
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -1378,7 +1378,7 @@ func (a *App) DeletePost(c request.CTX, postID, deleteByID string) (*model.Post,
|
||||
|
||||
a.Srv().Go(func() {
|
||||
if err = a.RemoveNotifications(c, post, channel); err != nil {
|
||||
a.Log().Error("DeletePost failed to delete notification", mlog.Err(err))
|
||||
c.Logger().Error("DeletePost failed to delete notification", mlog.Err(err))
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/request"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/email"
|
||||
emailmocks "github.com/mattermost/mattermost/server/v8/channels/app/email/mocks"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/teams"
|
||||
@ -1027,7 +1028,12 @@ func TestLeaveTeamPanic(t *testing.T) {
|
||||
mockLicenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil)
|
||||
|
||||
mockTeamStore := mocks.TeamStore{}
|
||||
mockTeamStore.On("GetMember", sqlstore.RequestContextWithMaster(th.Context), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil)
|
||||
mockTeamStore.On("GetMember", mock.AnythingOfType("*request.Context"), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil).Run(func(args mock.Arguments) {
|
||||
c, ok := args[0].(request.CTX)
|
||||
require.True(t, ok)
|
||||
|
||||
sqlstore.HasMaster(c.Context())
|
||||
})
|
||||
mockTeamStore.On("UpdateMember", mock.Anything).Return(nil, errors.New("repro error")) // This is the line that triggers the error
|
||||
|
||||
mockStore.On("Channel").Return(&mockChannelStore)
|
||||
|
@ -203,7 +203,7 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (
|
||||
}()
|
||||
|
||||
// fetch the session from store to check for inconsistencies.
|
||||
c.SetContext(WithMaster(c.Context()))
|
||||
c = c.With(RequestContextWithMaster)
|
||||
if storedSession, err := a.GetUploadSession(c, us.Id); err != nil {
|
||||
return nil, err
|
||||
} else if us.FileOffset != storedSession.FileOffset {
|
||||
@ -318,11 +318,10 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) (
|
||||
|
||||
if *a.Config().FileSettings.ExtractContent {
|
||||
infoCopy := *info
|
||||
crctx := c.Clone()
|
||||
a.Srv().Go(func() {
|
||||
err := a.ExtractContentFromFileInfo(crctx, &infoCopy)
|
||||
err := a.ExtractContentFromFileInfo(c, &infoCopy)
|
||||
if err != nil {
|
||||
crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
|
||||
c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ package jobs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
@ -33,8 +34,9 @@ type BatchMigrationWorker struct {
|
||||
store store.Store
|
||||
app BatchMigrationWorkerAppIFace
|
||||
|
||||
stop chan bool
|
||||
stop chan struct{}
|
||||
stopped chan bool
|
||||
closed atomic.Bool
|
||||
jobs chan model.Job
|
||||
|
||||
migrationKey string
|
||||
@ -49,7 +51,7 @@ func MakeBatchMigrationWorker(jobServer *JobServer, store store.Store, app Batch
|
||||
logger: jobServer.Logger().With(mlog.String("worker_name", migrationKey)),
|
||||
store: store,
|
||||
app: app,
|
||||
stop: make(chan bool, 1),
|
||||
stop: make(chan struct{}),
|
||||
stopped: make(chan bool, 1),
|
||||
jobs: make(chan model.Job),
|
||||
migrationKey: migrationKey,
|
||||
@ -64,7 +66,9 @@ func (worker *BatchMigrationWorker) Run() {
|
||||
worker.logger.Debug("Worker started")
|
||||
// We have to re-assign the stop channel again, because
|
||||
// it might happen that the job was restarted due to a config change.
|
||||
worker.stop = make(chan bool, 1)
|
||||
if worker.closed.CompareAndSwap(true, false) {
|
||||
worker.stop = make(chan struct{})
|
||||
}
|
||||
|
||||
defer func() {
|
||||
worker.logger.Debug("Worker finished")
|
||||
@ -84,6 +88,11 @@ func (worker *BatchMigrationWorker) Run() {
|
||||
|
||||
// Stop interrupts the worker even if the migration has not yet completed.
|
||||
func (worker *BatchMigrationWorker) Stop() {
|
||||
// Set to close, and if already closed before, then return.
|
||||
if !worker.closed.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
worker.logger.Debug("Worker stopping")
|
||||
close(worker.stop)
|
||||
<-worker.stopped
|
||||
|
@ -88,7 +88,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo
|
||||
IncludeCacheLayer: includeCacheLayer,
|
||||
ConfigStore: configStore,
|
||||
}
|
||||
th.Context.SetLogger(testLogger)
|
||||
|
||||
prevListenAddress := *th.App.Config().ServiceSettings.ListenAddress
|
||||
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = "localhost:0" })
|
||||
|
@ -100,10 +100,10 @@ func (worker *Worker) DoJob(job *model.Job) {
|
||||
return
|
||||
}
|
||||
|
||||
cancelContext := request.EmptyContext(worker.logger)
|
||||
var cancelContext request.CTX = request.EmptyContext(worker.logger)
|
||||
cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background())
|
||||
cancelWatcherChan := make(chan struct{}, 1)
|
||||
cancelContext.SetContext(cancelCtx)
|
||||
cancelContext = cancelContext.WithContext(cancelCtx)
|
||||
go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan)
|
||||
defer cancelCancelWatcher()
|
||||
|
||||
|
@ -29,7 +29,7 @@ type S3PathMigrationWorker struct {
|
||||
store store.Store
|
||||
fileBackend *filestore.S3FileBackend
|
||||
|
||||
stop chan bool
|
||||
stop chan struct{}
|
||||
stopped chan bool
|
||||
jobs chan model.Job
|
||||
}
|
||||
@ -45,7 +45,7 @@ func MakeWorker(jobServer *jobs.JobServer, store store.Store, fileBackend filest
|
||||
logger: jobServer.Logger().With(mlog.String("worker_name", workerName)),
|
||||
store: store,
|
||||
fileBackend: s3Backend,
|
||||
stop: make(chan bool, 1),
|
||||
stop: make(chan struct{}),
|
||||
stopped: make(chan bool, 1),
|
||||
jobs: make(chan model.Job),
|
||||
}
|
||||
@ -56,7 +56,7 @@ func (worker *S3PathMigrationWorker) Run() {
|
||||
worker.logger.Debug("Worker started")
|
||||
// We have to re-assign the stop channel again, because
|
||||
// it might happen that the job was restarted due to a config change.
|
||||
worker.stop = make(chan bool, 1)
|
||||
worker.stop = make(chan struct{}, 1)
|
||||
|
||||
defer func() {
|
||||
worker.logger.Debug("Worker finished")
|
||||
|
@ -31,7 +31,7 @@ func WithMaster(ctx context.Context) context.Context {
|
||||
// RequestContextWithMaster adds the context value that master DB should be selected for this request.
|
||||
func RequestContextWithMaster(c request.CTX) request.CTX {
|
||||
ctx := WithMaster(c.Context())
|
||||
c.SetContext(ctx)
|
||||
c = c.WithContext(ctx)
|
||||
return c
|
||||
}
|
||||
|
||||
|
@ -26,39 +26,30 @@ func TestRequestContextWithMaster(t *testing.T) {
|
||||
assert.True(t, HasMaster(rctx.Context()))
|
||||
})
|
||||
|
||||
t.Run("directly assigning does cause the child to alter the parent", func(t *testing.T) {
|
||||
var rctx request.CTX = request.TestContext(t)
|
||||
rctxClone := rctx
|
||||
rctxClone = RequestContextWithMaster(rctxClone)
|
||||
|
||||
assert.True(t, HasMaster(rctx.Context()))
|
||||
assert.True(t, HasMaster(rctxClone.Context()))
|
||||
})
|
||||
|
||||
t.Run("values get copied from parent", func(t *testing.T) {
|
||||
t.Run("values get copied from original context", func(t *testing.T) {
|
||||
var rctx request.CTX = request.TestContext(t)
|
||||
rctx = RequestContextWithMaster(rctx)
|
||||
rctxClone := rctx.Clone()
|
||||
rctxCopy := rctx
|
||||
|
||||
assert.True(t, HasMaster(rctx.Context()))
|
||||
assert.True(t, HasMaster(rctxClone.Context()))
|
||||
assert.True(t, HasMaster(rctxCopy.Context()))
|
||||
})
|
||||
|
||||
t.Run("changing the child does not alter the parent", func(t *testing.T) {
|
||||
t.Run("directly assigning does not cause the copy to alter the original context", func(t *testing.T) {
|
||||
var rctx request.CTX = request.TestContext(t)
|
||||
rctxClone := rctx.Clone()
|
||||
rctxClone = RequestContextWithMaster(rctxClone)
|
||||
rctxCopy := rctx
|
||||
rctxCopy = RequestContextWithMaster(rctxCopy)
|
||||
|
||||
assert.False(t, HasMaster(rctx.Context()))
|
||||
assert.True(t, HasMaster(rctxClone.Context()))
|
||||
assert.True(t, HasMaster(rctxCopy.Context()))
|
||||
})
|
||||
|
||||
t.Run("changing the parent does not alter the child", func(t *testing.T) {
|
||||
t.Run("directly assigning does not cause the original context to alter the copy", func(t *testing.T) {
|
||||
var rctx request.CTX = request.TestContext(t)
|
||||
rctxClone := rctx.Clone()
|
||||
rctxCopy := rctx
|
||||
rctx = RequestContextWithMaster(rctx)
|
||||
|
||||
assert.True(t, HasMaster(rctx.Context()))
|
||||
assert.False(t, HasMaster(rctxClone.Context()))
|
||||
assert.False(t, HasMaster(rctxCopy.Context()))
|
||||
})
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
|
||||
type Context struct {
|
||||
App app.AppIface
|
||||
AppContext *request.Context
|
||||
AppContext request.CTX
|
||||
Logger *mlog.Logger
|
||||
Params *Params
|
||||
Err *model.AppError
|
||||
|
@ -71,7 +71,7 @@ func TestMfaRequired(t *testing.T) {
|
||||
|
||||
th.App.Srv().SetLicense(model.NewTestLicense("mfa"))
|
||||
|
||||
th.Context.SetSession(&model.Session{Id: "abc", UserId: "userid"})
|
||||
th.Context = th.Context.WithSession(&model.Session{Id: "abc", UserId: "userid"})
|
||||
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.AnnouncementSettings.UserNoticesEnabled = false
|
||||
|
@ -205,7 +205,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
span.Finish()
|
||||
}()
|
||||
c.AppContext.SetContext(ctx)
|
||||
c.AppContext = c.AppContext.WithContext(ctx)
|
||||
|
||||
tmpSrv := *c.App.Srv()
|
||||
tmpSrv.SetStore(opentracinglayer.New(c.App.Srv().Store(), ctx))
|
||||
@ -285,7 +285,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
} else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString {
|
||||
c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized)
|
||||
} else {
|
||||
c.AppContext.SetSession(session)
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
}
|
||||
|
||||
// Rate limit by UserID
|
||||
@ -301,7 +301,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.Logger.Warn("Invalid CWS token", mlog.Err(err))
|
||||
c.Err = err
|
||||
} else {
|
||||
c.AppContext.SetSession(session)
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
}
|
||||
} else if token != "" && c.App.Channels().License() != nil && c.App.Channels().License().HasRemoteClusterService() && tokenLocation == app.TokenLocationRemoteClusterHeader {
|
||||
// Get the remote cluster
|
||||
@ -315,7 +315,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c.Logger.Warn("Invalid remote cluster token", mlog.Err(err))
|
||||
c.Err = err
|
||||
} else {
|
||||
c.AppContext.SetSession(session)
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -327,7 +327,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
mlog.String("user_id", c.AppContext.Session().UserId),
|
||||
mlog.String("method", r.Method),
|
||||
)
|
||||
c.AppContext.SetLogger(c.Logger)
|
||||
c.AppContext = c.AppContext.WithLogger(c.Logger)
|
||||
|
||||
if c.Err == nil && h.RequireSession {
|
||||
c.SessionRequired()
|
||||
@ -354,7 +354,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// shape IP:PORT (it will be "@" in Linux, for example)
|
||||
isLocalOrigin := !strings.Contains(r.RemoteAddr, ":")
|
||||
if *c.App.Config().ServiceSettings.EnableLocalMode && isLocalOrigin {
|
||||
c.AppContext.SetSession(&model.Session{Local: true})
|
||||
c.AppContext = c.AppContext.WithSession(&model.Session{Local: true})
|
||||
} else if !isLocalOrigin {
|
||||
c.Err = model.NewAppError("", "api.context.local_origin_required.app_error", nil, "LocalOriginRequired", http.StatusUnauthorized)
|
||||
}
|
||||
@ -501,7 +501,7 @@ func (h *Handler) checkCSRFToken(c *Context, r *http.Request, token string, toke
|
||||
}
|
||||
|
||||
if !csrfCheckPassed {
|
||||
c.AppContext.SetSession(&model.Session{})
|
||||
c.AppContext = c.AppContext.WithSession(&model.Session{})
|
||||
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
@ -321,13 +321,14 @@ func completeOAuth(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
} else if action == model.OAuthActionSSOToEmail {
|
||||
redirectURL = app.GetProtocol(r) + "://" + r.Host + "/claim?email=" + url.QueryEscape(props["email"])
|
||||
} else {
|
||||
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false)
|
||||
session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false)
|
||||
if err != nil {
|
||||
err.Translate(c.AppContext.T)
|
||||
mlog.Error(err.Error())
|
||||
renderError(err)
|
||||
return
|
||||
}
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
|
||||
// Old mobile version
|
||||
if isMobile && !hasRedirectURL {
|
||||
|
@ -18,8 +18,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/i18n"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/public/shared/request"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/utils"
|
||||
"github.com/mattermost/mattermost/server/v8/einterfaces"
|
||||
@ -397,20 +395,18 @@ func TestMobileLoginWithOAuth(t *testing.T) {
|
||||
c := &Context{
|
||||
App: th.App,
|
||||
AppContext: th.Context,
|
||||
Logger: th.TestLogger,
|
||||
Params: &Params{
|
||||
Service: "gitlab",
|
||||
},
|
||||
}
|
||||
|
||||
var siteURL = "http://localhost:8065"
|
||||
siteURL := "http://localhost:8065"
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.SiteURL = siteURL
|
||||
*cfg.GitLabSettings.Enable = true
|
||||
})
|
||||
|
||||
translationFunc := i18n.GetUserTranslations("en")
|
||||
c.AppContext.SetT(translationFunc)
|
||||
c.Logger = th.TestLogger
|
||||
provider := &MattermostTestProvider{}
|
||||
einterfaces.RegisterOAuthProvider(model.ServiceGitlab, provider)
|
||||
|
||||
@ -617,14 +613,12 @@ func TestOAuthComplete_ErrorMessages(t *testing.T) {
|
||||
c := &Context{
|
||||
App: th.App,
|
||||
AppContext: th.Context,
|
||||
Logger: th.TestLogger,
|
||||
Params: &Params{
|
||||
Service: "gitlab",
|
||||
},
|
||||
}
|
||||
|
||||
translationFunc := i18n.GetUserTranslations("en")
|
||||
c.AppContext.SetT(translationFunc)
|
||||
c.Logger = mlog.CreateConsoleTestLogger(t)
|
||||
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.GitLabSettings.Enable = true })
|
||||
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.EnableOAuthServiceProvider = true })
|
||||
provider := &MattermostTestProvider{}
|
||||
|
@ -178,11 +178,12 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
auditRec.AddMeta("obtained_user_id", user.Id)
|
||||
c.LogAuditWithUserId(user.Id, "obtained user")
|
||||
|
||||
err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true)
|
||||
session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true)
|
||||
if err != nil {
|
||||
handleError(err)
|
||||
return
|
||||
}
|
||||
c.AppContext = c.AppContext.WithSession(session)
|
||||
|
||||
auditRec.Success()
|
||||
c.LogAuditWithUserId(user.Id, "success")
|
||||
|
@ -33,7 +33,7 @@ var URL string
|
||||
|
||||
type TestHelper struct {
|
||||
App app.AppIface
|
||||
Context *request.Context
|
||||
Context request.CTX
|
||||
Server *app.Server
|
||||
Web *Web
|
||||
|
||||
@ -141,7 +141,6 @@ func setupTestHelper(tb testing.TB, includeCacheLayer bool, options []app.Option
|
||||
IncludeCacheLayer: includeCacheLayer,
|
||||
TestLogger: testLogger,
|
||||
}
|
||||
th.Context.SetLogger(testLogger)
|
||||
|
||||
return th
|
||||
}
|
||||
|
@ -138,10 +138,10 @@ func scheduleExportCmdF(command *cobra.Command, args []string) error {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
c := request.EmptyContext(a.Log())
|
||||
c.SetContext(ctx)
|
||||
var rctx request.CTX = request.EmptyContext(a.Log())
|
||||
rctx = rctx.WithContext(ctx)
|
||||
|
||||
job, err := messageExportI.StartSynchronizeJob(c, startTime)
|
||||
job, err := messageExportI.StartSynchronizeJob(rctx, startTime)
|
||||
if err != nil || job.Status == model.JobStatusError || job.Status == model.JobStatusCanceled {
|
||||
CommandPrintErrorln("ERROR: Message export job failed. Please check the server logs")
|
||||
} else {
|
||||
|
@ -253,10 +253,10 @@ func (worker *BleveIndexerWorker) DoJob(job *model.Job) {
|
||||
progress.TotalFilesCount = count
|
||||
}
|
||||
|
||||
cancelContext := request.EmptyContext(worker.logger)
|
||||
var cancelContext request.CTX = request.EmptyContext(worker.logger)
|
||||
cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background())
|
||||
cancelWatcherChan := make(chan struct{}, 1)
|
||||
cancelContext.SetContext(cancelCtx)
|
||||
cancelContext = cancelContext.WithContext(cancelCtx)
|
||||
go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan)
|
||||
defer cancelCancelWatcher()
|
||||
|
||||
|
@ -22,8 +22,7 @@ type Context struct {
|
||||
userAgent string
|
||||
acceptLanguage string
|
||||
logger mlog.LoggerIFace
|
||||
|
||||
context context.Context
|
||||
context context.Context
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context, requestId, ipAddress, xForwardedFor, path, userAgent, acceptLanguage string, t i18n.TranslateFunc) *Context {
|
||||
@ -54,8 +53,8 @@ func TestContext(t testing.TB) *Context {
|
||||
return EmptyContext(logger)
|
||||
}
|
||||
|
||||
// Clone creates a shallow copy of Context, allowing clones to apply per-request changes.
|
||||
func (c *Context) Clone() CTX {
|
||||
// clone creates a shallow copy of Context, allowing clones to apply per-request changes.
|
||||
func (c *Context) clone() *Context {
|
||||
cCopy := *c
|
||||
return &cCopy
|
||||
}
|
||||
@ -63,6 +62,9 @@ func (c *Context) Clone() CTX {
|
||||
func (c *Context) T(translationID string, args ...any) string {
|
||||
return c.t(translationID, args...)
|
||||
}
|
||||
func (c *Context) GetT() i18n.TranslateFunc {
|
||||
return c.t
|
||||
}
|
||||
func (c *Context) Session() *model.Session {
|
||||
return &c.session
|
||||
}
|
||||
@ -84,55 +86,71 @@ func (c *Context) UserAgent() string {
|
||||
func (c *Context) AcceptLanguage() string {
|
||||
return c.acceptLanguage
|
||||
}
|
||||
|
||||
func (c *Context) Logger() mlog.LoggerIFace {
|
||||
return c.logger
|
||||
}
|
||||
func (c *Context) Context() context.Context {
|
||||
return c.context
|
||||
}
|
||||
|
||||
func (c *Context) SetSession(s *model.Session) {
|
||||
c.session = *s
|
||||
func (c *Context) WithT(t i18n.TranslateFunc) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.t = t
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithSession(s *model.Session) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.session = *s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithRequestId(s string) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.requestId = s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithIPAddress(s string) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.ipAddress = s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithXForwardedFor(s string) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.xForwardedFor = s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithPath(s string) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.path = s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithUserAgent(s string) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.userAgent = s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithAcceptLanguage(s string) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.acceptLanguage = s
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithContext(ctx context.Context) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.context = ctx
|
||||
return rctx
|
||||
}
|
||||
func (c *Context) WithLogger(logger mlog.LoggerIFace) CTX {
|
||||
rctx := c.clone()
|
||||
rctx.logger = logger
|
||||
return rctx
|
||||
}
|
||||
|
||||
func (c *Context) SetT(t i18n.TranslateFunc) {
|
||||
c.t = t
|
||||
}
|
||||
func (c *Context) SetRequestId(s string) {
|
||||
c.requestId = s
|
||||
}
|
||||
func (c *Context) SetIPAddress(s string) {
|
||||
c.ipAddress = s
|
||||
}
|
||||
func (c *Context) SetXForwardedFor(s string) {
|
||||
c.xForwardedFor = s
|
||||
}
|
||||
func (c *Context) SetUserAgent(s string) {
|
||||
c.userAgent = s
|
||||
}
|
||||
func (c *Context) SetAcceptLanguage(s string) {
|
||||
c.acceptLanguage = s
|
||||
}
|
||||
func (c *Context) SetPath(s string) {
|
||||
c.path = s
|
||||
}
|
||||
func (c *Context) SetContext(ctx context.Context) {
|
||||
c.context = ctx
|
||||
}
|
||||
|
||||
func (c *Context) GetT() i18n.TranslateFunc {
|
||||
return c.t
|
||||
}
|
||||
|
||||
func (c *Context) SetLogger(logger mlog.LoggerIFace) {
|
||||
c.logger = logger
|
||||
}
|
||||
|
||||
func (c *Context) Logger() mlog.LoggerIFace {
|
||||
return c.logger
|
||||
func (c *Context) With(f func(ctx CTX) CTX) CTX {
|
||||
return f(c)
|
||||
}
|
||||
|
||||
type CTX interface {
|
||||
Clone() CTX
|
||||
T(string, ...interface{}) string
|
||||
GetT() i18n.TranslateFunc
|
||||
Session() *model.Session
|
||||
RequestId() string
|
||||
IPAddress() string
|
||||
@ -140,16 +158,17 @@ type CTX interface {
|
||||
Path() string
|
||||
UserAgent() string
|
||||
AcceptLanguage() string
|
||||
Context() context.Context
|
||||
SetSession(s *model.Session)
|
||||
SetT(i18n.TranslateFunc)
|
||||
SetRequestId(string)
|
||||
SetIPAddress(string)
|
||||
SetUserAgent(string)
|
||||
SetAcceptLanguage(string)
|
||||
SetPath(string)
|
||||
SetContext(ctx context.Context)
|
||||
GetT() i18n.TranslateFunc
|
||||
SetLogger(mlog.LoggerIFace)
|
||||
Logger() mlog.LoggerIFace
|
||||
Context() context.Context
|
||||
WithT(i18n.TranslateFunc) CTX
|
||||
WithSession(s *model.Session) CTX
|
||||
WithRequestId(string) CTX
|
||||
WithIPAddress(string) CTX
|
||||
WithXForwardedFor(string) CTX
|
||||
WithPath(string) CTX
|
||||
WithUserAgent(string) CTX
|
||||
WithAcceptLanguage(string) CTX
|
||||
WithLogger(mlog.LoggerIFace) CTX
|
||||
WithContext(ctx context.Context) CTX
|
||||
With(func(ctx CTX) CTX) CTX
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user