diff --git a/server/channels/api4/apitestlib.go b/server/channels/api4/apitestlib.go index b5df550140..11d9c27ced 100644 --- a/server/channels/api4/apitestlib.go +++ b/server/channels/api4/apitestlib.go @@ -140,13 +140,12 @@ func setupTestHelper(dbStore store.Store, searchEngine *searchengine.Broker, ent th := &TestHelper{ App: app.New(app.ServerConnector(s.Channels())), Server: s, + Context: request.EmptyContext(testLogger), ConfigStore: configStore, IncludeCacheLayer: includeCache, - Context: request.EmptyContext(testLogger), TestLogger: testLogger, LogBuffer: buffer, } - th.Context.SetLogger(testLogger) if s.Platform().SearchEngine != nil && s.Platform().SearchEngine.BleveEngine != nil && searchEngine != nil { searchEngine.BleveEngine = s.Platform().SearchEngine.BleveEngine diff --git a/server/channels/api4/channel.go b/server/channels/api4/channel.go index 03a37c2327..4d2d8f1a75 100644 --- a/server/channels/api4/channel.go +++ b/server/channels/api4/channel.go @@ -4,7 +4,6 @@ package api4 import ( - "context" "encoding/json" "net/http" "strconv" @@ -1460,9 +1459,8 @@ func getChannelMember(c *Context, w http.ResponseWriter, r *http.Request) { return } - ctx := c.AppContext - ctx.SetContext(app.WithMaster(ctx.Context())) - member, err := c.App.GetChannelMember(ctx, c.Params.ChannelId, c.Params.UserId) + c.AppContext = c.AppContext.With(app.RequestContextWithMaster) + member, err := c.App.GetChannelMember(c.AppContext, c.Params.ChannelId, c.Params.UserId) if err != nil { c.Err = err return @@ -2004,7 +2002,7 @@ func channelMemberCountsByGroup(c *Context, w http.ResponseWriter, r *http.Reque includeTimezones := r.URL.Query().Get("include_timezones") == "true" - channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(app.WithMaster(context.Background()), c.Params.ChannelId, includeTimezones) + channelMemberCounts, appErr := c.App.GetMemberCountsByGroup(c.AppContext.With(app.RequestContextWithMaster), c.Params.ChannelId, includeTimezones) if appErr != nil { c.Err = appErr return diff --git a/server/channels/api4/remote_cluster.go b/server/channels/api4/remote_cluster.go index 93e43259c2..277100d69d 100644 --- a/server/channels/api4/remote_cluster.go +++ b/server/channels/api4/remote_cluster.go @@ -195,7 +195,7 @@ func uploadRemoteData(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId) - c.AppContext.SetContext(app.WithMaster(c.AppContext.Context())) + c.AppContext = c.AppContext.With(app.RequestContextWithMaster) us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId) if err != nil { c.Err = err diff --git a/server/channels/api4/upload.go b/server/channels/api4/upload.go index 1d07c34d44..7891493ff6 100644 --- a/server/channels/api4/upload.go +++ b/server/channels/api4/upload.go @@ -123,7 +123,7 @@ func uploadData(c *Context, w http.ResponseWriter, r *http.Request) { defer c.LogAuditRec(auditRec) audit.AddEventParameter(auditRec, "upload_id", c.Params.UploadId) - c.AppContext.SetContext(app.WithMaster(c.AppContext.Context())) + c.AppContext = c.AppContext.With(app.RequestContextWithMaster) us, err := c.App.GetUploadSession(c.AppContext, c.Params.UploadId) if err != nil { c.Err = err diff --git a/server/channels/api4/user.go b/server/channels/api4/user.go index f242c112b8..2cf563e937 100644 --- a/server/channels/api4/user.go +++ b/server/channels/api4/user.go @@ -1953,11 +1953,12 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) { c.LogAuditWithUserId(user.Id, "authenticated") isMobileDevice := utils.IsMobileRequest(r) - err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false) + session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, isMobileDevice, false, false) if err != nil { c.Err = err return } + c.AppContext = c.AppContext.WithSession(session) c.LogAuditWithUserId(user.Id, "success") @@ -1995,11 +1996,12 @@ func loginWithDesktopToken(c *Context, w http.ResponseWriter, r *http.Request) { return } - err = c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false) + session, err := c.App.DoLogin(c.AppContext, w, r, user, deviceId, false, false, false) if err != nil { c.Err = err return } + c.AppContext = c.AppContext.WithSession(session) c.App.AttachSessionCookies(c.AppContext, w, r) @@ -2048,12 +2050,13 @@ func loginCWS(c *Context, w http.ResponseWriter, r *http.Request) { audit.AddEventParameterAuditable(auditRec, "user", user) c.LogAuditWithUserId(user.Id, "authenticated") isMobileDevice := utils.IsMobileRequest(r) - err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false) + session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobileDevice, false, false) if err != nil { c.LogErrorByCode(err) http.Redirect(w, r, *c.App.Config().ServiceSettings.SiteURL, http.StatusFound) return } + c.AppContext = c.AppContext.WithSession(session) c.LogAuditWithUserId(user.Id, "success") c.App.AttachSessionCookies(c.AppContext, w, r) diff --git a/server/channels/app/app_iface.go b/server/channels/app/app_iface.go index e1991a00c3..7530159e24 100644 --- a/server/channels/app/app_iface.go +++ b/server/channels/app/app_iface.go @@ -553,7 +553,7 @@ type AppIface interface { DoEmojisPermissionsMigration() DoGuestRolesCreationMigration() DoLocalRequest(c request.CTX, rawURL string, body []byte) (*http.Response, *model.AppError) - DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError + DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError) DoPostAction(c request.CTX, postID, actionId, userID, selectedOption string) (string, *model.AppError) DoPostActionWithCookie(c request.CTX, postID, actionId, userID, selectedOption string, cookie *model.PostActionCookie) (string, *model.AppError) DoSystemConsoleRolesCreationMigration() @@ -687,7 +687,7 @@ type AppIface interface { GetLatestVersion(latestVersionUrl string) (*model.GithubReleaseInfo, *model.AppError) GetLogs(c request.CTX, page, perPage int) ([]string, *model.AppError) GetLogsSkipSend(page, perPage int, logFilter *model.LogFilter) ([]string, *model.AppError) - GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) + GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) GetMessageForNotification(post *model.Post, translateFunc i18n.TranslateFunc) string GetMultipleEmojiByName(c request.CTX, names []string) ([]*model.Emoji, *model.AppError) GetNewUsersForTeamPage(teamID string, page, perPage int, asAdmin bool, viewRestrictions *model.ViewUsersRestrictions) ([]*model.User, *model.AppError) diff --git a/server/channels/app/channel.go b/server/channels/app/channel.go index e03a088e7e..40af28ba66 100644 --- a/server/channels/app/channel.go +++ b/server/channels/app/channel.go @@ -1666,7 +1666,9 @@ func (a *App) AddChannelMember(c request.CTX, userID string, channel *model.Chan } } else { a.Srv().Go(func() { - a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID) + if err := a.PostAddToChannelMessage(c, userRequestor, user, channel, opts.PostRootID); err != nil { + c.Logger().Error("Failed to post AddToChannel message", mlog.Err(err)) + } }) } @@ -2363,7 +2365,9 @@ func (a *App) LeaveChannel(c request.CTX, channelID string, userID string) *mode } a.Srv().Go(func() { - a.postLeaveChannelMessage(c, user, channel) + if err := a.postLeaveChannelMessage(c, user, channel); err != nil { + c.Logger().Error("Failed to post LeaveChannel message", mlog.Err(err)) + } }) return nil @@ -3458,8 +3462,8 @@ func (a *App) ClearChannelMembersCache(c request.CTX, channelID string) error { return nil } -func (a *App) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) { - channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(ctx, channelID, includeTimezones) +func (a *App) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) { + channelMemberCounts, err := a.Srv().Store().Channel().GetMemberCountsByGroup(rctx.Context(), channelID, includeTimezones) if err != nil { return nil, model.NewAppError("GetMemberCountsByGroup", "app.channel.get_member_count.app_error", nil, "", http.StatusInternalServerError).Wrap(err) } diff --git a/server/channels/app/channel_test.go b/server/channels/app/channel_test.go index 802aaedf2d..af04e7f8ff 100644 --- a/server/channels/app/channel_test.go +++ b/server/channels/app/channel_test.go @@ -2152,7 +2152,7 @@ func TestGetMemberCountsByGroup(t *testing.T) { mockChannelStore.On("GetMemberCountsByGroup", context.Background(), "channelID", true).Return(cmc, nil) mockStore.On("Channel").Return(&mockChannelStore) mockStore.On("GetDBSchemaVersion").Return(1, nil) - resp, err := th.App.GetMemberCountsByGroup(context.Background(), "channelID", true) + resp, err := th.App.GetMemberCountsByGroup(th.Context, "channelID", true) require.Nil(t, err) require.ElementsMatch(t, cmc, resp) } diff --git a/server/channels/app/compliance.go b/server/channels/app/compliance.go index d0fcd1cfc2..0d35607ae1 100644 --- a/server/channels/app/compliance.go +++ b/server/channels/app/compliance.go @@ -34,7 +34,7 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo job.Type = model.ComplianceTypeAdhoc - rctx.SetLogger(rctx.Logger().With(job.LoggerFields()...)) + rctx = rctx.WithLogger(rctx.Logger().With(job.LoggerFields()...)) job, err := a.Srv().Store().Compliance().Save(job) if err != nil { @@ -48,11 +48,10 @@ func (a *App) SaveComplianceReport(rctx request.CTX, job *model.Compliance) (*mo } jCopy := job.DeepCopy() - crctx := rctx.Clone() a.Srv().Go(func() { - err := a.Compliance().RunComplianceJob(crctx, jCopy) + err := a.Compliance().RunComplianceJob(rctx, jCopy) if err != nil { - crctx.Logger().Warn("Error running compliance job", mlog.Err(err)) + rctx.Logger().Warn("Error running compliance job", mlog.Err(err)) } }) diff --git a/server/channels/app/context.go b/server/channels/app/context.go index cd5cab3aa1..ea47242ed4 100644 --- a/server/channels/app/context.go +++ b/server/channels/app/context.go @@ -4,16 +4,14 @@ package app import ( - "context" - "github.com/mattermost/mattermost/server/public/plugin" "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/store/sqlstore" ) -// WithMaster adds the context value that master DB should be selected for this request. -func WithMaster(ctx context.Context) context.Context { - return sqlstore.WithMaster(ctx) +// RequestContextWithMaster adds the context value that master DB should be selected for this request. +func RequestContextWithMaster(c request.CTX) request.CTX { + return sqlstore.RequestContextWithMaster(c) } func pluginContext(c request.CTX) *plugin.Context { diff --git a/server/channels/app/file.go b/server/channels/app/file.go index 1e40ed83cf..238b53b080 100644 --- a/server/channels/app/file.go +++ b/server/channels/app/file.go @@ -796,11 +796,10 @@ func (a *App) UploadFileX(c request.CTX, channelID, name string, input io.Reader if *a.Config().FileSettings.ExtractContent { infoCopy := *t.fileinfo - crctx := c.Clone() a.Srv().GoBuffered(func() { - err := a.ExtractContentFromFileInfo(crctx, &infoCopy) + err := a.ExtractContentFromFileInfo(c, &infoCopy) if err != nil { - crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) + c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) } }) } @@ -1048,11 +1047,10 @@ func (a *App) DoUploadFileExpectModification(c request.CTX, now time.Time, rawTe if *a.Config().FileSettings.ExtractContent { infoCopy := *info - crctx := c.Clone() a.Srv().GoBuffered(func() { - err := a.ExtractContentFromFileInfo(crctx, &infoCopy) + err := a.ExtractContentFromFileInfo(c, &infoCopy) if err != nil { - crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) + c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) } }) } diff --git a/server/channels/app/helper_test.go b/server/channels/app/helper_test.go index e4292b4246..34c1d489b6 100644 --- a/server/channels/app/helper_test.go +++ b/server/channels/app/helper_test.go @@ -99,7 +99,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo IncludeCacheLayer: includeCacheLayer, ConfigStore: configStore, } - th.Context.SetLogger(testLogger) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.TeamSettings.MaxUsersPerTeam = 50 }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.RateLimitSettings.Enable = false }) diff --git a/server/channels/app/import.go b/server/channels/app/import.go index 9390d4fe0c..4a630ba562 100644 --- a/server/channels/app/import.go +++ b/server/channels/app/import.go @@ -268,7 +268,7 @@ func (a *App) bulkImport(c request.CTX, jsonlReader io.Reader, attachmentsReader linesChan = make(chan imports.LineImportWorkerData, workers) for i := 0; i < workers; i++ { wg.Add(1) - go a.bulkImportWorker(c.Clone(), dryRun, &wg, linesChan, errorsChan) + go a.bulkImportWorker(c, dryRun, &wg, linesChan, errorsChan) } } diff --git a/server/channels/app/login.go b/server/channels/app/login.go index 7c5d62610c..0bfc210803 100644 --- a/server/channels/app/login.go +++ b/server/channels/app/login.go @@ -156,7 +156,7 @@ func (a *App) GetUserForLogin(c request.CTX, id, loginId string) (*model.User, * return nil, model.NewAppError("GetUserForLogin", "store.sql_user.get_for_login.app_error", nil, "", http.StatusBadRequest) } -func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) *model.AppError { +func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile, isOAuthUser, isSaml bool) (*model.Session, *model.AppError) { var rejectionReason string pluginContext := pluginContext(c) a.ch.RunMultiHook(func(hooks plugin.Hooks) bool { @@ -165,7 +165,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use }, plugin.UserWillLogInID) if rejectionReason != "" { - return model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest) + return nil, model.NewAppError("DoLogin", "Login rejected by plugin: "+rejectionReason, nil, "", http.StatusBadRequest) } session := &model.Session{UserId: user.Id, Roles: user.GetRawRoles(), DeviceId: deviceID, IsOAuth: false, Props: map[string]string{ @@ -181,7 +181,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use // A special case where we logout of all other sessions with the same Id if err := a.RevokeSessionsForDeviceId(c, user.Id, deviceID, ""); err != nil { err.StatusCode = http.StatusInternalServerError - return err + return nil, err } } else if isMobile { a.ch.srv.platform.SetSessionExpireInHours(session, *a.Config().ServiceSettings.SessionLengthMobileInHours) @@ -210,12 +210,12 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use var err *model.AppError if session, err = a.CreateSession(c, session); err != nil { err.StatusCode = http.StatusInternalServerError - return err + return nil, err } w.Header().Set(model.HeaderToken, session.Token) - c.SetSession(session) + c = c.WithSession(session) if a.Srv().License() != nil && *a.Srv().License().Features.LDAP && a.Ldap() != nil { userVal := *user sessionVal := *session @@ -231,7 +231,7 @@ func (a *App) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, use }, plugin.UserHasLoggedInID) }) - return nil + return session, nil } func (a *App) AttachCloudSessionCookie(c request.CTX, w http.ResponseWriter, r *http.Request) { diff --git a/server/channels/app/opentracing/opentracing_layer.go b/server/channels/app/opentracing/opentracing_layer.go index fbc5a75bc1..d8aa0d566c 100644 --- a/server/channels/app/opentracing/opentracing_layer.go +++ b/server/channels/app/opentracing/opentracing_layer.go @@ -3828,7 +3828,7 @@ func (a *OpenTracingAppLayer) DoLocalRequest(c request.CTX, rawURL string, body return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) *model.AppError { +func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *http.Request, user *model.User, deviceID string, isMobile bool, isOAuthUser bool, isSaml bool) (*model.Session, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.DoLogin") @@ -3840,14 +3840,14 @@ func (a *OpenTracingAppLayer) DoLogin(c request.CTX, w http.ResponseWriter, r *h }() defer span.Finish() - resultVar0 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml) + resultVar0, resultVar1 := a.app.DoLogin(c, w, r, user, deviceID, isMobile, isOAuthUser, isSaml) - if resultVar0 != nil { - span.LogFields(spanlog.Error(resultVar0)) + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) ext.Error.Set(span, true) } - return resultVar0 + return resultVar0, resultVar1 } func (a *OpenTracingAppLayer) DoPermissionsMigrations() error { @@ -7320,7 +7320,7 @@ func (a *OpenTracingAppLayer) GetMarketplacePlugins(filter *model.MarketplacePlu return resultVar0, resultVar1 } -func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) { +func (a *OpenTracingAppLayer) GetMemberCountsByGroup(rctx request.CTX, channelID string, includeTimezones bool) ([]*model.ChannelMemberCountByGroup, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetMemberCountsByGroup") @@ -7332,7 +7332,7 @@ func (a *OpenTracingAppLayer) GetMemberCountsByGroup(ctx context.Context, channe }() defer span.Finish() - resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(ctx, channelID, includeTimezones) + resultVar0, resultVar1 := a.app.GetMemberCountsByGroup(rctx, channelID, includeTimezones) if resultVar1 != nil { span.LogFields(spanlog.Error(resultVar1)) diff --git a/server/channels/app/plugin_api.go b/server/channels/app/plugin_api.go index 4b723aac7d..714a4abe8b 100644 --- a/server/channels/app/plugin_api.go +++ b/server/channels/app/plugin_api.go @@ -1263,7 +1263,7 @@ func (api *PluginAPI) UploadData(us *model.UploadSession, rd io.Reader) (*model. func (api *PluginAPI) GetUploadSession(uploadID string) (*model.UploadSession, error) { // We want to fetch from master DB to avoid a potential read-after-write on the plugin side. - api.ctx.SetContext(WithMaster(api.ctx.Context())) + api.ctx = api.ctx.With(RequestContextWithMaster) fi, err := api.app.GetUploadSession(api.ctx, uploadID) if err != nil { return nil, err diff --git a/server/channels/app/plugin_hooks_test.go b/server/channels/app/plugin_hooks_test.go index 9a6fbc6b0c..34268ec006 100644 --- a/server/channels/app/plugin_hooks_test.go +++ b/server/channels/app/plugin_hooks_test.go @@ -727,9 +727,10 @@ func TestUserWillLogIn_Blocked(t *testing.T) { r := &http.Request{} w := httptest.NewRecorder() - err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) assert.Contains(t, err.Id, "Login rejected by plugin", "Expected Login rejected by plugin, got %s", err.Id) + assert.Nil(t, session) } func TestUserWillLogInIn_Passed(t *testing.T) { @@ -766,10 +767,11 @@ func TestUserWillLogInIn_Passed(t *testing.T) { r := &http.Request{} w := httptest.NewRecorder() - err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) assert.Nil(t, err, "Expected nil, got %s", err) - assert.Equal(t, th.Context.Session().UserId, th.BasicUser.Id) + require.NotNil(t, session) + assert.Equal(t, session.UserId, th.BasicUser.Id) } func TestUserHasLoggedIn(t *testing.T) { @@ -807,9 +809,10 @@ func TestUserHasLoggedIn(t *testing.T) { r := &http.Request{} w := httptest.NewRecorder() - err = th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) assert.Nil(t, err, "Expected nil, got %s", err) + assert.NotNil(t, session) time.Sleep(2 * time.Second) diff --git a/server/channels/app/post.go b/server/channels/app/post.go index 2fb26b05fe..cb908970eb 100644 --- a/server/channels/app/post.go +++ b/server/channels/app/post.go @@ -532,7 +532,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User a.Srv().Go(func() { _, err := a.SendAutoResponseIfNecessary(c, channel, user, post) if err != nil { - mlog.Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err)) + c.Logger().Error("Failed to send auto response", mlog.String("user_id", user.Id), mlog.String("post_id", post.Id), mlog.Err(err)) } }) } @@ -540,7 +540,7 @@ func (a *App) handlePostEvents(c request.CTX, post *model.Post, user *model.User if triggerWebhooks { a.Srv().Go(func() { if err := a.handleWebhookEvents(c, post, team, channel, user); err != nil { - mlog.Error(err.Error()) + c.Logger().Error("Failed to handle webhook event", mlog.Err(err)) } }) } @@ -1378,7 +1378,7 @@ func (a *App) DeletePost(c request.CTX, postID, deleteByID string) (*model.Post, a.Srv().Go(func() { if err = a.RemoveNotifications(c, post, channel); err != nil { - a.Log().Error("DeletePost failed to delete notification", mlog.Err(err)) + c.Logger().Error("DeletePost failed to delete notification", mlog.Err(err)) } }) diff --git a/server/channels/app/team_test.go b/server/channels/app/team_test.go index b2ab165f3e..aa0c48dc46 100644 --- a/server/channels/app/team_test.go +++ b/server/channels/app/team_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/app/email" emailmocks "github.com/mattermost/mattermost/server/v8/channels/app/email/mocks" "github.com/mattermost/mattermost/server/v8/channels/app/teams" @@ -1027,7 +1028,12 @@ func TestLeaveTeamPanic(t *testing.T) { mockLicenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil) mockTeamStore := mocks.TeamStore{} - mockTeamStore.On("GetMember", sqlstore.RequestContextWithMaster(th.Context), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil) + mockTeamStore.On("GetMember", mock.AnythingOfType("*request.Context"), "myteam", "userID").Return(&model.TeamMember{TeamId: "myteam", UserId: "userID"}, nil).Run(func(args mock.Arguments) { + c, ok := args[0].(request.CTX) + require.True(t, ok) + + sqlstore.HasMaster(c.Context()) + }) mockTeamStore.On("UpdateMember", mock.Anything).Return(nil, errors.New("repro error")) // This is the line that triggers the error mockStore.On("Channel").Return(&mockChannelStore) diff --git a/server/channels/app/upload.go b/server/channels/app/upload.go index 3a0f231428..f147467ea1 100644 --- a/server/channels/app/upload.go +++ b/server/channels/app/upload.go @@ -203,7 +203,7 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) ( }() // fetch the session from store to check for inconsistencies. - c.SetContext(WithMaster(c.Context())) + c = c.With(RequestContextWithMaster) if storedSession, err := a.GetUploadSession(c, us.Id); err != nil { return nil, err } else if us.FileOffset != storedSession.FileOffset { @@ -318,11 +318,10 @@ func (a *App) UploadData(c request.CTX, us *model.UploadSession, rd io.Reader) ( if *a.Config().FileSettings.ExtractContent { infoCopy := *info - crctx := c.Clone() a.Srv().Go(func() { - err := a.ExtractContentFromFileInfo(crctx, &infoCopy) + err := a.ExtractContentFromFileInfo(c, &infoCopy) if err != nil { - crctx.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) + c.Logger().Error("Failed to extract file content", mlog.Err(err), mlog.String("fileInfoId", infoCopy.Id)) } }) } diff --git a/server/channels/jobs/batch_migration_worker.go b/server/channels/jobs/batch_migration_worker.go index 464170fa8e..183450ad20 100644 --- a/server/channels/jobs/batch_migration_worker.go +++ b/server/channels/jobs/batch_migration_worker.go @@ -5,6 +5,7 @@ package jobs import ( "net/http" + "sync/atomic" "time" "github.com/mattermost/mattermost/server/public/model" @@ -33,8 +34,9 @@ type BatchMigrationWorker struct { store store.Store app BatchMigrationWorkerAppIFace - stop chan bool + stop chan struct{} stopped chan bool + closed atomic.Bool jobs chan model.Job migrationKey string @@ -49,7 +51,7 @@ func MakeBatchMigrationWorker(jobServer *JobServer, store store.Store, app Batch logger: jobServer.Logger().With(mlog.String("worker_name", migrationKey)), store: store, app: app, - stop: make(chan bool, 1), + stop: make(chan struct{}), stopped: make(chan bool, 1), jobs: make(chan model.Job), migrationKey: migrationKey, @@ -64,7 +66,9 @@ func (worker *BatchMigrationWorker) Run() { worker.logger.Debug("Worker started") // We have to re-assign the stop channel again, because // it might happen that the job was restarted due to a config change. - worker.stop = make(chan bool, 1) + if worker.closed.CompareAndSwap(true, false) { + worker.stop = make(chan struct{}) + } defer func() { worker.logger.Debug("Worker finished") @@ -84,6 +88,11 @@ func (worker *BatchMigrationWorker) Run() { // Stop interrupts the worker even if the migration has not yet completed. func (worker *BatchMigrationWorker) Stop() { + // Set to close, and if already closed before, then return. + if !worker.closed.CompareAndSwap(false, true) { + return + } + worker.logger.Debug("Worker stopping") close(worker.stop) <-worker.stopped diff --git a/server/channels/jobs/helper_test.go b/server/channels/jobs/helper_test.go index fb2867ab9d..c80f612899 100644 --- a/server/channels/jobs/helper_test.go +++ b/server/channels/jobs/helper_test.go @@ -88,7 +88,6 @@ func setupTestHelper(dbStore store.Store, enterprise bool, includeCacheLayer boo IncludeCacheLayer: includeCacheLayer, ConfigStore: configStore, } - th.Context.SetLogger(testLogger) prevListenAddress := *th.App.Config().ServiceSettings.ListenAddress th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = "localhost:0" }) diff --git a/server/channels/jobs/migrations/worker.go b/server/channels/jobs/migrations/worker.go index ffc81026d1..5cf60ad087 100644 --- a/server/channels/jobs/migrations/worker.go +++ b/server/channels/jobs/migrations/worker.go @@ -100,10 +100,10 @@ func (worker *Worker) DoJob(job *model.Job) { return } - cancelContext := request.EmptyContext(worker.logger) + var cancelContext request.CTX = request.EmptyContext(worker.logger) cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background()) cancelWatcherChan := make(chan struct{}, 1) - cancelContext.SetContext(cancelCtx) + cancelContext = cancelContext.WithContext(cancelCtx) go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan) defer cancelCancelWatcher() diff --git a/server/channels/jobs/s3_path_migration/s3_path_migration.go b/server/channels/jobs/s3_path_migration/s3_path_migration.go index 3732fbf133..28203261a4 100644 --- a/server/channels/jobs/s3_path_migration/s3_path_migration.go +++ b/server/channels/jobs/s3_path_migration/s3_path_migration.go @@ -29,7 +29,7 @@ type S3PathMigrationWorker struct { store store.Store fileBackend *filestore.S3FileBackend - stop chan bool + stop chan struct{} stopped chan bool jobs chan model.Job } @@ -45,7 +45,7 @@ func MakeWorker(jobServer *jobs.JobServer, store store.Store, fileBackend filest logger: jobServer.Logger().With(mlog.String("worker_name", workerName)), store: store, fileBackend: s3Backend, - stop: make(chan bool, 1), + stop: make(chan struct{}), stopped: make(chan bool, 1), jobs: make(chan model.Job), } @@ -56,7 +56,7 @@ func (worker *S3PathMigrationWorker) Run() { worker.logger.Debug("Worker started") // We have to re-assign the stop channel again, because // it might happen that the job was restarted due to a config change. - worker.stop = make(chan bool, 1) + worker.stop = make(chan struct{}, 1) defer func() { worker.logger.Debug("Worker finished") diff --git a/server/channels/store/sqlstore/context.go b/server/channels/store/sqlstore/context.go index 8b43c5e415..114f9c4718 100644 --- a/server/channels/store/sqlstore/context.go +++ b/server/channels/store/sqlstore/context.go @@ -31,7 +31,7 @@ func WithMaster(ctx context.Context) context.Context { // RequestContextWithMaster adds the context value that master DB should be selected for this request. func RequestContextWithMaster(c request.CTX) request.CTX { ctx := WithMaster(c.Context()) - c.SetContext(ctx) + c = c.WithContext(ctx) return c } diff --git a/server/channels/store/sqlstore/context_test.go b/server/channels/store/sqlstore/context_test.go index e477cb3c56..6bc60c12cf 100644 --- a/server/channels/store/sqlstore/context_test.go +++ b/server/channels/store/sqlstore/context_test.go @@ -26,39 +26,30 @@ func TestRequestContextWithMaster(t *testing.T) { assert.True(t, HasMaster(rctx.Context())) }) - t.Run("directly assigning does cause the child to alter the parent", func(t *testing.T) { - var rctx request.CTX = request.TestContext(t) - rctxClone := rctx - rctxClone = RequestContextWithMaster(rctxClone) - - assert.True(t, HasMaster(rctx.Context())) - assert.True(t, HasMaster(rctxClone.Context())) - }) - - t.Run("values get copied from parent", func(t *testing.T) { + t.Run("values get copied from original context", func(t *testing.T) { var rctx request.CTX = request.TestContext(t) rctx = RequestContextWithMaster(rctx) - rctxClone := rctx.Clone() + rctxCopy := rctx assert.True(t, HasMaster(rctx.Context())) - assert.True(t, HasMaster(rctxClone.Context())) + assert.True(t, HasMaster(rctxCopy.Context())) }) - t.Run("changing the child does not alter the parent", func(t *testing.T) { + t.Run("directly assigning does not cause the copy to alter the original context", func(t *testing.T) { var rctx request.CTX = request.TestContext(t) - rctxClone := rctx.Clone() - rctxClone = RequestContextWithMaster(rctxClone) + rctxCopy := rctx + rctxCopy = RequestContextWithMaster(rctxCopy) assert.False(t, HasMaster(rctx.Context())) - assert.True(t, HasMaster(rctxClone.Context())) + assert.True(t, HasMaster(rctxCopy.Context())) }) - t.Run("changing the parent does not alter the child", func(t *testing.T) { + t.Run("directly assigning does not cause the original context to alter the copy", func(t *testing.T) { var rctx request.CTX = request.TestContext(t) - rctxClone := rctx.Clone() + rctxCopy := rctx rctx = RequestContextWithMaster(rctx) assert.True(t, HasMaster(rctx.Context())) - assert.False(t, HasMaster(rctxClone.Context())) + assert.False(t, HasMaster(rctxCopy.Context())) }) } diff --git a/server/channels/web/context.go b/server/channels/web/context.go index d1890d2598..76e6a55e3f 100644 --- a/server/channels/web/context.go +++ b/server/channels/web/context.go @@ -20,7 +20,7 @@ import ( type Context struct { App app.AppIface - AppContext *request.Context + AppContext request.CTX Logger *mlog.Logger Params *Params Err *model.AppError diff --git a/server/channels/web/context_test.go b/server/channels/web/context_test.go index a294d96a26..79f8f1d8ea 100644 --- a/server/channels/web/context_test.go +++ b/server/channels/web/context_test.go @@ -71,7 +71,7 @@ func TestMfaRequired(t *testing.T) { th.App.Srv().SetLicense(model.NewTestLicense("mfa")) - th.Context.SetSession(&model.Session{Id: "abc", UserId: "userid"}) + th.Context = th.Context.WithSession(&model.Session{Id: "abc", UserId: "userid"}) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.AnnouncementSettings.UserNoticesEnabled = false diff --git a/server/channels/web/handlers.go b/server/channels/web/handlers.go index cc981f7325..bfd9aebf71 100644 --- a/server/channels/web/handlers.go +++ b/server/channels/web/handlers.go @@ -205,7 +205,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } span.Finish() }() - c.AppContext.SetContext(ctx) + c.AppContext = c.AppContext.WithContext(ctx) tmpSrv := *c.App.Srv() tmpSrv.SetStore(opentracinglayer.New(c.App.Srv().Store(), ctx)) @@ -285,7 +285,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else if !session.IsOAuth && tokenLocation == app.TokenLocationQueryString { c.Err = model.NewAppError("ServeHTTP", "api.context.token_provided.app_error", nil, "token="+token, http.StatusUnauthorized) } else { - c.AppContext.SetSession(session) + c.AppContext = c.AppContext.WithSession(session) } // Rate limit by UserID @@ -301,7 +301,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.Logger.Warn("Invalid CWS token", mlog.Err(err)) c.Err = err } else { - c.AppContext.SetSession(session) + c.AppContext = c.AppContext.WithSession(session) } } else if token != "" && c.App.Channels().License() != nil && c.App.Channels().License().HasRemoteClusterService() && tokenLocation == app.TokenLocationRemoteClusterHeader { // Get the remote cluster @@ -315,7 +315,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.Logger.Warn("Invalid remote cluster token", mlog.Err(err)) c.Err = err } else { - c.AppContext.SetSession(session) + c.AppContext = c.AppContext.WithSession(session) } } } @@ -327,7 +327,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { mlog.String("user_id", c.AppContext.Session().UserId), mlog.String("method", r.Method), ) - c.AppContext.SetLogger(c.Logger) + c.AppContext = c.AppContext.WithLogger(c.Logger) if c.Err == nil && h.RequireSession { c.SessionRequired() @@ -354,7 +354,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // shape IP:PORT (it will be "@" in Linux, for example) isLocalOrigin := !strings.Contains(r.RemoteAddr, ":") if *c.App.Config().ServiceSettings.EnableLocalMode && isLocalOrigin { - c.AppContext.SetSession(&model.Session{Local: true}) + c.AppContext = c.AppContext.WithSession(&model.Session{Local: true}) } else if !isLocalOrigin { c.Err = model.NewAppError("", "api.context.local_origin_required.app_error", nil, "LocalOriginRequired", http.StatusUnauthorized) } @@ -501,7 +501,7 @@ func (h *Handler) checkCSRFToken(c *Context, r *http.Request, token string, toke } if !csrfCheckPassed { - c.AppContext.SetSession(&model.Session{}) + c.AppContext = c.AppContext.WithSession(&model.Session{}) c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized) } } diff --git a/server/channels/web/oauth.go b/server/channels/web/oauth.go index 67a9890c27..2b3f232e69 100644 --- a/server/channels/web/oauth.go +++ b/server/channels/web/oauth.go @@ -321,13 +321,14 @@ func completeOAuth(c *Context, w http.ResponseWriter, r *http.Request) { } else if action == model.OAuthActionSSOToEmail { redirectURL = app.GetProtocol(r) + "://" + r.Host + "/claim?email=" + url.QueryEscape(props["email"]) } else { - err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false) + session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, false) if err != nil { err.Translate(c.AppContext.T) mlog.Error(err.Error()) renderError(err) return } + c.AppContext = c.AppContext.WithSession(session) // Old mobile version if isMobile && !hasRedirectURL { diff --git a/server/channels/web/oauth_test.go b/server/channels/web/oauth_test.go index 297025bcab..4e60fecac8 100644 --- a/server/channels/web/oauth_test.go +++ b/server/channels/web/oauth_test.go @@ -18,8 +18,6 @@ import ( "github.com/stretchr/testify/require" "github.com/mattermost/mattermost/server/public/model" - "github.com/mattermost/mattermost/server/public/shared/i18n" - "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/channels/utils" "github.com/mattermost/mattermost/server/v8/einterfaces" @@ -397,20 +395,18 @@ func TestMobileLoginWithOAuth(t *testing.T) { c := &Context{ App: th.App, AppContext: th.Context, + Logger: th.TestLogger, Params: &Params{ Service: "gitlab", }, } - var siteURL = "http://localhost:8065" + siteURL := "http://localhost:8065" th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.SiteURL = siteURL *cfg.GitLabSettings.Enable = true }) - translationFunc := i18n.GetUserTranslations("en") - c.AppContext.SetT(translationFunc) - c.Logger = th.TestLogger provider := &MattermostTestProvider{} einterfaces.RegisterOAuthProvider(model.ServiceGitlab, provider) @@ -617,14 +613,12 @@ func TestOAuthComplete_ErrorMessages(t *testing.T) { c := &Context{ App: th.App, AppContext: th.Context, + Logger: th.TestLogger, Params: &Params{ Service: "gitlab", }, } - translationFunc := i18n.GetUserTranslations("en") - c.AppContext.SetT(translationFunc) - c.Logger = mlog.CreateConsoleTestLogger(t) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.GitLabSettings.Enable = true }) th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.EnableOAuthServiceProvider = true }) provider := &MattermostTestProvider{} diff --git a/server/channels/web/saml.go b/server/channels/web/saml.go index 798d53e9d0..e376db98f5 100644 --- a/server/channels/web/saml.go +++ b/server/channels/web/saml.go @@ -178,11 +178,12 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { auditRec.AddMeta("obtained_user_id", user.Id) c.LogAuditWithUserId(user.Id, "obtained user") - err = c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true) + session, err := c.App.DoLogin(c.AppContext, w, r, user, "", isMobile, false, true) if err != nil { handleError(err) return } + c.AppContext = c.AppContext.WithSession(session) auditRec.Success() c.LogAuditWithUserId(user.Id, "success") diff --git a/server/channels/web/web_test.go b/server/channels/web/web_test.go index 70df148dbd..f96337dd89 100644 --- a/server/channels/web/web_test.go +++ b/server/channels/web/web_test.go @@ -33,7 +33,7 @@ var URL string type TestHelper struct { App app.AppIface - Context *request.Context + Context request.CTX Server *app.Server Web *Web @@ -141,7 +141,6 @@ func setupTestHelper(tb testing.TB, includeCacheLayer bool, options []app.Option IncludeCacheLayer: includeCacheLayer, TestLogger: testLogger, } - th.Context.SetLogger(testLogger) return th } diff --git a/server/cmd/mattermost/commands/export.go b/server/cmd/mattermost/commands/export.go index 042e549672..aa8990f444 100644 --- a/server/cmd/mattermost/commands/export.go +++ b/server/cmd/mattermost/commands/export.go @@ -138,10 +138,10 @@ func scheduleExportCmdF(command *cobra.Command, args []string) error { defer cancel() } - c := request.EmptyContext(a.Log()) - c.SetContext(ctx) + var rctx request.CTX = request.EmptyContext(a.Log()) + rctx = rctx.WithContext(ctx) - job, err := messageExportI.StartSynchronizeJob(c, startTime) + job, err := messageExportI.StartSynchronizeJob(rctx, startTime) if err != nil || job.Status == model.JobStatusError || job.Status == model.JobStatusCanceled { CommandPrintErrorln("ERROR: Message export job failed. Please check the server logs") } else { diff --git a/server/platform/services/searchengine/bleveengine/indexer/indexing_job.go b/server/platform/services/searchengine/bleveengine/indexer/indexing_job.go index 4e042870a8..5d17133b1b 100644 --- a/server/platform/services/searchengine/bleveengine/indexer/indexing_job.go +++ b/server/platform/services/searchengine/bleveengine/indexer/indexing_job.go @@ -253,10 +253,10 @@ func (worker *BleveIndexerWorker) DoJob(job *model.Job) { progress.TotalFilesCount = count } - cancelContext := request.EmptyContext(worker.logger) + var cancelContext request.CTX = request.EmptyContext(worker.logger) cancelCtx, cancelCancelWatcher := context.WithCancel(context.Background()) cancelWatcherChan := make(chan struct{}, 1) - cancelContext.SetContext(cancelCtx) + cancelContext = cancelContext.WithContext(cancelCtx) go worker.jobServer.CancellationWatcher(cancelContext, job.Id, cancelWatcherChan) defer cancelCancelWatcher() diff --git a/server/public/shared/request/context.go b/server/public/shared/request/context.go index bd073794cd..18b783e53b 100644 --- a/server/public/shared/request/context.go +++ b/server/public/shared/request/context.go @@ -22,8 +22,7 @@ type Context struct { userAgent string acceptLanguage string logger mlog.LoggerIFace - - context context.Context + context context.Context } func NewContext(ctx context.Context, requestId, ipAddress, xForwardedFor, path, userAgent, acceptLanguage string, t i18n.TranslateFunc) *Context { @@ -54,8 +53,8 @@ func TestContext(t testing.TB) *Context { return EmptyContext(logger) } -// Clone creates a shallow copy of Context, allowing clones to apply per-request changes. -func (c *Context) Clone() CTX { +// clone creates a shallow copy of Context, allowing clones to apply per-request changes. +func (c *Context) clone() *Context { cCopy := *c return &cCopy } @@ -63,6 +62,9 @@ func (c *Context) Clone() CTX { func (c *Context) T(translationID string, args ...any) string { return c.t(translationID, args...) } +func (c *Context) GetT() i18n.TranslateFunc { + return c.t +} func (c *Context) Session() *model.Session { return &c.session } @@ -84,55 +86,71 @@ func (c *Context) UserAgent() string { func (c *Context) AcceptLanguage() string { return c.acceptLanguage } - +func (c *Context) Logger() mlog.LoggerIFace { + return c.logger +} func (c *Context) Context() context.Context { return c.context } -func (c *Context) SetSession(s *model.Session) { - c.session = *s +func (c *Context) WithT(t i18n.TranslateFunc) CTX { + rctx := c.clone() + rctx.t = t + return rctx +} +func (c *Context) WithSession(s *model.Session) CTX { + rctx := c.clone() + rctx.session = *s + return rctx +} +func (c *Context) WithRequestId(s string) CTX { + rctx := c.clone() + rctx.requestId = s + return rctx +} +func (c *Context) WithIPAddress(s string) CTX { + rctx := c.clone() + rctx.ipAddress = s + return rctx +} +func (c *Context) WithXForwardedFor(s string) CTX { + rctx := c.clone() + rctx.xForwardedFor = s + return rctx +} +func (c *Context) WithPath(s string) CTX { + rctx := c.clone() + rctx.path = s + return rctx +} +func (c *Context) WithUserAgent(s string) CTX { + rctx := c.clone() + rctx.userAgent = s + return rctx +} +func (c *Context) WithAcceptLanguage(s string) CTX { + rctx := c.clone() + rctx.acceptLanguage = s + return rctx +} +func (c *Context) WithContext(ctx context.Context) CTX { + rctx := c.clone() + rctx.context = ctx + return rctx +} +func (c *Context) WithLogger(logger mlog.LoggerIFace) CTX { + rctx := c.clone() + rctx.logger = logger + return rctx } -func (c *Context) SetT(t i18n.TranslateFunc) { - c.t = t -} -func (c *Context) SetRequestId(s string) { - c.requestId = s -} -func (c *Context) SetIPAddress(s string) { - c.ipAddress = s -} -func (c *Context) SetXForwardedFor(s string) { - c.xForwardedFor = s -} -func (c *Context) SetUserAgent(s string) { - c.userAgent = s -} -func (c *Context) SetAcceptLanguage(s string) { - c.acceptLanguage = s -} -func (c *Context) SetPath(s string) { - c.path = s -} -func (c *Context) SetContext(ctx context.Context) { - c.context = ctx -} - -func (c *Context) GetT() i18n.TranslateFunc { - return c.t -} - -func (c *Context) SetLogger(logger mlog.LoggerIFace) { - c.logger = logger -} - -func (c *Context) Logger() mlog.LoggerIFace { - return c.logger +func (c *Context) With(f func(ctx CTX) CTX) CTX { + return f(c) } type CTX interface { - Clone() CTX T(string, ...interface{}) string + GetT() i18n.TranslateFunc Session() *model.Session RequestId() string IPAddress() string @@ -140,16 +158,17 @@ type CTX interface { Path() string UserAgent() string AcceptLanguage() string - Context() context.Context - SetSession(s *model.Session) - SetT(i18n.TranslateFunc) - SetRequestId(string) - SetIPAddress(string) - SetUserAgent(string) - SetAcceptLanguage(string) - SetPath(string) - SetContext(ctx context.Context) - GetT() i18n.TranslateFunc - SetLogger(mlog.LoggerIFace) Logger() mlog.LoggerIFace + Context() context.Context + WithT(i18n.TranslateFunc) CTX + WithSession(s *model.Session) CTX + WithRequestId(string) CTX + WithIPAddress(string) CTX + WithXForwardedFor(string) CTX + WithPath(string) CTX + WithUserAgent(string) CTX + WithAcceptLanguage(string) CTX + WithLogger(mlog.LoggerIFace) CTX + WithContext(ctx context.Context) CTX + With(func(ctx CTX) CTX) CTX }