MM-30026: Use DB master when getting team members from a session (#16170)

* MM-30026: Use DB master when getting team members from a session

A race condition happens when the read-replica isn't updated yet
by the time a session expiry message reaches another node in the cluster.

Here is the sequence of events that can cause it:
- Server1 gets any request which has to wipe session cache.

- The SQL query is written to DB master, and a cluster message is propagated
to clear the session cache for that user.

- Now before the read-replica is updated with the master’s update,
the cluster message reaches Server2. The session cache is wiped out for that user.

- _Any random_ request for that user hits Server2. Does NOT have to be
the update team name request. The request does not find the value
in session cache, because it’s wiped off, and picks it up from the DB.
Surprise surprise, it gets the stale value. Sticks it into the cache.

By now, the read-replica is updated. But guess what, we aren’t going to
ask the DB anymore, because we have it in the cache. And the cache has the stale value.

We use a temporary approach for now by introducing a context in the DB calls so that
the useMaster information can be easily passed. And this has the added advantage of
reusing the same context for future DB calls in case it happens. And we can also
add more context keys as needed.

A proper approach needs some architectural changes. See the issue for more details.

```release-note
Fixed a bug where a session will hold on to a cached value
in an HA setup with read-replicas configured.
```

* incorporate review comments

Co-authored-by: Mattermod <mattermod@users.noreply.github.com>
This commit is contained in:
Agniva De Sarker
2020-11-10 10:43:45 +05:30
committed by GitHub
parent 6c71fbaebd
commit 39b5b601f8
14 changed files with 98 additions and 26 deletions

View File

@@ -5,6 +5,7 @@ package app
import (
"bytes"
"context"
"crypto/sha1"
"errors"
"fmt"
@@ -677,7 +678,7 @@ func (a *App) importUserTeams(user *model.User, data *[]UserTeamImportData) *mod
isGuestByTeamId := map[string]bool{}
isUserByTeamId := map[string]bool{}
isAdminByTeamId := map[string]bool{}
existingMemberships, nErr := a.Srv().Store.Team().GetTeamsForUser(user.Id)
existingMemberships, nErr := a.Srv().Store.Team().GetTeamsForUser(context.Background(), user.Id)
if nErr != nil {
return model.NewAppError("importUserTeams", "app.team.get_members.app_error", nil, nErr.Error(), http.StatusInternalServerError)
}

View File

@@ -4,6 +4,7 @@
package app
import (
"context"
"io/ioutil"
"os"
"path/filepath"
@@ -1711,7 +1712,7 @@ func TestImportUserTeams(t *testing.T) {
} else {
require.Nil(t, err)
}
teamMembers, nErr := th.App.Srv().Store.Team().GetTeamsForUser(user.Id)
teamMembers, nErr := th.App.Srv().Store.Team().GetTeamsForUser(context.Background(), user.Id)
require.Nil(t, nErr)
require.Len(t, teamMembers, tc.expectedUserTeams)
if tc.expectedUserTeams == 1 {

View File

@@ -5,6 +5,7 @@ package app
import (
"bytes"
"context"
"errors"
"fmt"
"image"
@@ -1011,7 +1012,7 @@ func (a *App) GetTeamMember(teamId, userId string) (*model.TeamMember, *model.Ap
}
func (a *App) GetTeamMembersForUser(userId string) ([]*model.TeamMember, *model.AppError) {
teamMembers, err := a.Srv().Store.Team().GetTeamsForUser(userId)
teamMembers, err := a.Srv().Store.Team().GetTeamsForUser(context.Background(), userId)
if err != nil {
return nil, model.NewAppError("GetTeamMembersForUser", "app.team.get_members.app_error", nil, err.Error(), http.StatusInternalServerError)
}

View File

@@ -67,7 +67,10 @@ func (s LocalCacheRoleStore) GetByNames(names []string) ([]*model.Role, error) {
}
}
roles, _ := s.RoleStore.GetByNames(rolesToQuery)
roles, err := s.RoleStore.GetByNames(rolesToQuery)
if err != nil {
return nil, err
}
for _, role := range roles {
s.rootStore.doStandardAddToCache(s.rootStore.roleCache, role.Name, role)

View File

@@ -7113,7 +7113,7 @@ func (s *OpenTracingLayerTeamStore) GetTeamsByUserId(userId string) ([]*model.Te
return result, err
}
func (s *OpenTracingLayerTeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error) {
func (s *OpenTracingLayerTeamStore) GetTeamsForUser(ctx context.Context, userId string) ([]*model.TeamMember, error) {
origCtx := s.Root.Store.Context()
span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "TeamStore.GetTeamsForUser")
s.Root.Store.SetContext(newCtx)
@@ -7122,7 +7122,7 @@ func (s *OpenTracingLayerTeamStore) GetTeamsForUser(userId string) ([]*model.Tea
}()
defer span.Finish()
result, err := s.TeamStore.GetTeamsForUser(userId)
result, err := s.TeamStore.GetTeamsForUser(ctx, userId)
if err != nil {
span.LogFields(spanlog.Error(err))
ext.Error.Set(span, true)

View File

@@ -7096,11 +7096,11 @@ func (s *RetryLayerTeamStore) GetTeamsByUserId(userId string) ([]*model.Team, er
}
func (s *RetryLayerTeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error) {
func (s *RetryLayerTeamStore) GetTeamsForUser(ctx context.Context, userId string) ([]*model.TeamMember, error) {
tries := 0
for {
result, err := s.TeamStore.GetTeamsForUser(userId)
result, err := s.TeamStore.GetTeamsForUser(ctx, userId)
if err == nil {
return result, nil
}

32
store/sqlstore/context.go Normal file
View File

@@ -0,0 +1,32 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package sqlstore
import "context"
// storeContextKey is the base type for all context keys for the store.
type storeContextKey string
// contextValue is a type to hold some pre-determined context values.
type contextValue string
// Different possible values of contextValue.
const (
useMaster contextValue = "useMaster"
)
// withMaster adds the context value that master DB should be selected for this request.
func withMaster(ctx context.Context) context.Context {
return context.WithValue(ctx, storeContextKey(useMaster), true)
}
// hasMaster is a helper function to check whether master DB should be selected or not.
func hasMaster(ctx context.Context) bool {
if v := ctx.Value(storeContextKey(useMaster)); v != nil {
if res, ok := v.(bool); ok && res {
return true
}
}
return false
}

View File

@@ -0,0 +1,18 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package sqlstore
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestContextMaster(t *testing.T) {
ctx := context.Background()
m := withMaster(ctx)
assert.True(t, hasMaster(m))
}

View File

@@ -4,6 +4,7 @@
package sqlstore
import (
"context"
"fmt"
"time"
@@ -57,7 +58,7 @@ func (me SqlSessionStore) Save(session *model.Session) (*model.Session, error) {
return nil, errors.Wrapf(err, "failed to save Session with id=%s", session.Id)
}
teamMembers, err := me.Team().GetTeamsForUser(session.UserId)
teamMembers, err := me.Team().GetTeamsForUser(context.Background(), session.UserId)
if err != nil {
return nil, errors.Wrapf(err, "failed to find TeamMembers for Session with userId=%s", session.UserId)
}
@@ -82,7 +83,9 @@ func (me SqlSessionStore) Get(sessionIdOrToken string) (*model.Session, error) {
}
session := sessions[0]
tempMembers, err := me.Team().GetTeamsForUser(session.UserId)
tempMembers, err := me.Team().GetTeamsForUser(
withMaster(context.Background()),
session.UserId)
if err != nil {
return nil, errors.Wrapf(err, "failed to find TeamMembers for Session with userId=%s", session.UserId)
}
@@ -102,7 +105,7 @@ func (me SqlSessionStore) GetSessions(userId string) ([]*model.Session, error) {
return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId)
}
teamMembers, err := me.Team().GetTeamsForUser(userId)
teamMembers, err := me.Team().GetTeamsForUser(context.Background(), userId)
if err != nil {
return nil, errors.Wrapf(err, "failed to find TeamMembers for Session with userId=%s", userId)
}

View File

@@ -4,6 +4,7 @@
package sqlstore
import (
"context"
"database/sql"
"fmt"
"strings"
@@ -1146,7 +1147,7 @@ func (s SqlTeamStore) GetMembersByIds(teamId string, userIds []string, restricti
}
// GetTeamsForUser returns a list of teams that the user is a member of. Expects userId to be passed as a parameter.
func (s SqlTeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error) {
func (s SqlTeamStore) GetTeamsForUser(ctx context.Context, userId string) ([]*model.TeamMember, error) {
query := s.getTeamMembersWithSchemeSelectQuery().
Where(sq.Eq{"TeamMembers.UserId": userId})
@@ -1156,7 +1157,15 @@ func (s SqlTeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error
}
var dbMembers teamMemberWithSchemeRolesList
_, err = s.GetReplica().Select(&dbMembers, queryString, args...)
var db *gorp.DbMap
if hasMaster(ctx) {
db = s.GetMaster()
} else {
db = s.GetReplica()
}
_, err = db.Select(&dbMembers, queryString, args...)
if err != nil {
return nil, errors.Wrapf(err, "failed to find TeamMembers with userId=%s", userId)
}

View File

@@ -103,7 +103,7 @@ type TeamStore interface {
GetMembersByIds(teamId string, userIds []string, restrictions *model.ViewUsersRestrictions) ([]*model.TeamMember, error)
GetTotalMemberCount(teamId string, restrictions *model.ViewUsersRestrictions) (int64, error)
GetActiveMemberCount(teamId string, restrictions *model.ViewUsersRestrictions) (int64, error)
GetTeamsForUser(userId string) ([]*model.TeamMember, error)
GetTeamsForUser(ctx context.Context, userId string) ([]*model.TeamMember, error)
GetTeamsForUserWithPagination(userId string, page, perPage int) ([]*model.TeamMember, error)
GetChannelUnreadsForAllTeams(excludeTeamId, userId string) ([]*model.ChannelUnread, error)
GetChannelUnreadsForTeam(teamId, userId string) ([]*model.ChannelUnread, error)

View File

@@ -5,6 +5,8 @@
package mocks
import (
context "context"
model "github.com/mattermost/mattermost-server/v5/model"
mock "github.com/stretchr/testify/mock"
)
@@ -598,13 +600,13 @@ func (_m *TeamStore) GetTeamsByUserId(userId string) ([]*model.Team, error) {
return r0, r1
}
// GetTeamsForUser provides a mock function with given fields: userId
func (_m *TeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error) {
ret := _m.Called(userId)
// GetTeamsForUser provides a mock function with given fields: ctx, userId
func (_m *TeamStore) GetTeamsForUser(ctx context.Context, userId string) ([]*model.TeamMember, error) {
ret := _m.Called(ctx, userId)
var r0 []*model.TeamMember
if rf, ok := ret.Get(0).(func(string) []*model.TeamMember); ok {
r0 = rf(userId)
if rf, ok := ret.Get(0).(func(context.Context, string) []*model.TeamMember); ok {
r0 = rf(ctx, userId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*model.TeamMember)
@@ -612,8 +614,8 @@ func (_m *TeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error)
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(userId)
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, userId)
} else {
r1 = ret.Error(1)
}

View File

@@ -4,6 +4,7 @@
package storetest
import (
"context"
"errors"
"strings"
"testing"
@@ -1208,7 +1209,8 @@ func testTeamMembers(t *testing.T, ss store.Store) {
require.Len(t, ms, 1)
require.Equal(t, m3.UserId, ms[0].UserId)
ms, err = ss.Team().GetTeamsForUser(m1.UserId)
ctx := context.Background()
ms, err = ss.Team().GetTeamsForUser(ctx, m1.UserId)
require.Nil(t, err)
require.Len(t, ms, 1)
require.Equal(t, m1.TeamId, ms[0].TeamId)
@@ -1237,14 +1239,14 @@ func testTeamMembers(t *testing.T, ss store.Store) {
_, nErr = ss.Team().SaveMultipleMembers([]*model.TeamMember{m4, m5}, -1)
require.Nil(t, nErr)
ms, err = ss.Team().GetTeamsForUser(uid)
ms, err = ss.Team().GetTeamsForUser(ctx, uid)
require.Nil(t, err)
require.Len(t, ms, 2)
nErr = ss.Team().RemoveAllMembersByUser(uid)
require.Nil(t, nErr)
ms, err = ss.Team().GetTeamsForUser(m1.UserId)
ms, err = ss.Team().GetTeamsForUser(ctx, m1.UserId)
require.Nil(t, err)
require.Empty(t, ms)
}

View File

@@ -6425,10 +6425,10 @@ func (s *TimerLayerTeamStore) GetTeamsByUserId(userId string) ([]*model.Team, er
return result, err
}
func (s *TimerLayerTeamStore) GetTeamsForUser(userId string) ([]*model.TeamMember, error) {
func (s *TimerLayerTeamStore) GetTeamsForUser(ctx context.Context, userId string) ([]*model.TeamMember, error) {
start := timemodule.Now()
result, err := s.TeamStore.GetTeamsForUser(userId)
result, err := s.TeamStore.GetTeamsForUser(ctx, userId)
elapsed := float64(timemodule.Since(start)) / float64(timemodule.Second)
if s.Root.Metrics != nil {