Files
mattermost/app/web_hub.go
Agniva De Sarker 10f5a8890c MM-24972: make GetHubForUserId lock-less and zero-alloc (#14945)
* MM-24972: make GetHubForUserId lock-less and zero-alloc

The hubs aren't meant to be modified after the server starts up.
But the key problem was the SetHubs method which would zero out
the underlying hubs slice.

This wasn't strictly necessary because a hub would only be stopped
once during server shutdown. However, this was required by a test
which would shutdown the hub twice. And since the hub shutdown isn't
idempotent, zeroing the slice would just skip over shutting down the hub again.

To improve the overall situation, we apply several optimizations.

- We use the new hash/maphash package which exposes Go runtime's
internal hash algorithms to be used as a package. This is much faster
than hash/fnv.
- We move around the initialization of the hub to happen before
the metrics server starts. This allows us to initialize the hub
before any of the hub elements are being accessed.
- To make the test run successfully, we do not call th.TearDown.
This is fine for a test because the test process will eventually stop
and relinquish the resources to the OS.

This allows us to completely remove any mutexes and thereby
we can remove all the methods and any edge-case checks related to
index being out of bounds. As a result, the fast path becomes
very straightforward and zero-alloc.

name               old time/op    new time/op    delta
GetHubForUserId-8     116ns ± 1%      38ns ± 7%   -67.22%  (p=0.000 n=10+10)

name               old alloc/op   new alloc/op   delta
GetHubForUserId-8     36.0B ± 0%      0.0B       -100.00%  (p=0.000 n=10+10)

name               old allocs/op  new allocs/op  delta
GetHubForUserId-8      2.00 ± 0%      0.00       -100.00%  (p=0.000 n=10+10)

Manually tested with some load testing and running Hub tests in -race mode.

* remove mutex

* incorporate review comments
2020-07-07 11:23:45 +05:30

605 lines
16 KiB
Go

// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package app
import (
"hash/maphash"
"runtime"
"runtime/debug"
"strconv"
"sync/atomic"
"github.com/mattermost/mattermost-server/v5/mlog"
"github.com/mattermost/mattermost-server/v5/model"
)
const (
	// broadcastQueueSize is the buffer size of each hub's broadcast channel;
	// event producers only block once this many events are queued on a hub.
	broadcastQueueSize = 4096
)
// webConnActivityMessage carries a user-activity notification into the hub
// loop so the matching connection's lastUserActivityAt can be updated.
type webConnActivityMessage struct {
	userId       string
	sessionToken string // identifies which of the user's connections was active
	activityAt   int64  // timestamp of the observed activity
}
// webConnDirectMessage carries a websocket message destined for a single
// connection through the hub loop.
type webConnDirectMessage struct {
	conn *WebConn
	msg  model.WebSocketMessage
}
// webConnSessionMessage asks the hub loop whether a given session has a
// registered connection; the answer arrives on the isRegistered channel.
type webConnSessionMessage struct {
	userId       string
	sessionToken string
	isRegistered chan bool // hub loop sends exactly one reply per query
}
// Hub is the central place to manage all websocket connections in the server.
// It handles different websocket events and sending messages to individual
// user connections.
type Hub struct {
	// connectionCount should be kept first.
	// See https://github.com/mattermost/mattermost-server/pull/7281
	connectionCount int64
	app             *App
	// connectionIndex is this hub's position in the server's hubs slice.
	connectionIndex int
	register        chan *WebConn
	unregister      chan *WebConn
	// broadcast is buffered (broadcastQueueSize) so producers rarely block.
	broadcast chan *model.WebSocketEvent
	// stop is closed by Stop(); didStop is closed by the hub loop once
	// shutdown has completed.
	stop           chan struct{}
	didStop        chan struct{}
	invalidateUser chan string
	activity       chan *webConnActivityMessage
	directMsg      chan *webConnDirectMessage
	// explicitStop is set by the hub loop on requested shutdown and read by
	// the panic-recovery wrapper to decide whether to restart the loop.
	explicitStop    bool
	checkRegistered chan *webConnSessionMessage
}
// NewWebHub creates a new Hub with all of its channels initialized.
func (a *App) NewWebHub() *Hub {
	h := &Hub{app: a}
	h.register = make(chan *WebConn)
	h.unregister = make(chan *WebConn)
	// Buffered so that event producers rarely block on a busy hub.
	h.broadcast = make(chan *model.WebSocketEvent, broadcastQueueSize)
	h.stop = make(chan struct{})
	h.didStop = make(chan struct{})
	h.invalidateUser = make(chan string)
	h.activity = make(chan *webConnActivityMessage)
	h.directMsg = make(chan *webConnDirectMessage)
	h.checkRegistered = make(chan *webConnSessionMessage)
	return h
}
// TotalWebsocketConnections returns the total number of websocket connections
// across all hubs, delegating to the server-level counter.
func (a *App) TotalWebsocketConnections() int {
	return a.Srv().TotalWebsocketConnections()
}
// HubStart starts all the hubs.
func (a *App) HubStart() {
	// Total number of hubs is twice the number of CPUs.
	numberOfHubs := runtime.NumCPU() * 2
	mlog.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))

	hubs := make([]*Hub, numberOfHubs)
	for i := range hubs {
		hub := a.NewWebHub()
		hub.connectionIndex = i
		hub.Start()
		hubs[i] = hub
	}
	// Assigning without a mutex is safe: the slice is written exactly once
	// during startup and is only read from after that.
	a.srv.hubs = hubs
}
// invalidateCacheForUserSkipClusterSend invalidates the local channel-member
// and webconn session caches for the user without notifying the cluster.
func (a *App) invalidateCacheForUserSkipClusterSend(userId string) {
	a.Srv().Store.Channel().InvalidateAllChannelMembersForUser(userId)
	a.InvalidateWebConnSessionCacheForUser(userId)
}
// invalidateCacheForWebhook drops the cached entry for the given webhook id.
func (a *App) invalidateCacheForWebhook(webhookId string) {
	a.Srv().Store.Webhook().InvalidateWebhookCache(webhookId)
}
// InvalidateWebConnSessionCacheForUser invalidates the cached sessions of all
// of the user's connections on the hub responsible for that user.
func (a *App) InvalidateWebConnSessionCacheForUser(userId string) {
	if hub := a.GetHubForUserId(userId); hub != nil {
		hub.InvalidateUser(userId)
	}
}
// HubStop stops all the hubs and blocks until each one has shut down.
func (s *Server) HubStop() {
	mlog.Info("stopping websocket hub connections")
	for i := range s.hubs {
		s.hubs[i].Stop()
	}
}
// HubStop stops all the hubs, delegating to the server implementation.
func (a *App) HubStop() {
	a.Srv().HubStop()
}
// GetHubForUserId returns the hub for a given user id. The user id is hashed
// with the server-wide maphash seed and reduced modulo the number of hubs,
// so a given user always maps to the same hub for the process lifetime.
func (s *Server) GetHubForUserId(userId string) *Hub {
	// TODO: check if caching the userId -> hub mapping
	// is worth the memory tradeoff.
	// https://mattermost.atlassian.net/browse/MM-26629.
	var hash maphash.Hash
	hash.SetSeed(s.hashSeed)
	// WriteString avoids the []byte(userId) conversion that Write would
	// require; like Write, it is documented to never fail.
	hash.WriteString(userId)
	index := hash.Sum64() % uint64(len(s.hubs))
	return s.hubs[int(index)]
}
// GetHubForUserId returns the hub for a given user id.
func (a *App) GetHubForUserId(userId string) *Hub {
	return a.Srv().GetHubForUserId(userId)
}
// HubRegister registers a connection to a hub.
func (a *App) HubRegister(webConn *WebConn) {
	hub := a.GetHubForUserId(webConn.UserId)
	if hub == nil {
		return
	}
	if metrics := a.Metrics(); metrics != nil {
		metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
	}
	hub.Register(webConn)
}
// HubUnregister unregisters a connection from a hub.
func (a *App) HubUnregister(webConn *WebConn) {
	hub := a.GetHubForUserId(webConn.UserId)
	if hub == nil {
		return
	}
	if metrics := a.Metrics(); metrics != nil {
		metrics.DecrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
	}
	hub.Unregister(webConn)
}
// Publish delivers the event to the local hubs and, when clustering is
// enabled, relays it to the other nodes in the cluster.
func (s *Server) Publish(message *model.WebSocketEvent) {
	if s.Metrics != nil {
		s.Metrics.IncrementWebsocketEvent(message.EventType())
	}

	s.PublishSkipClusterSend(message)

	if s.Cluster == nil {
		return
	}

	cm := &model.ClusterMessage{
		Event:    model.CLUSTER_EVENT_PUBLISH,
		SendType: model.CLUSTER_SEND_BEST_EFFORT,
		Data:     message.ToJson(),
	}
	// Events that users must not miss are sent with reliable delivery.
	switch message.EventType() {
	case model.WEBSOCKET_EVENT_POSTED,
		model.WEBSOCKET_EVENT_POST_EDITED,
		model.WEBSOCKET_EVENT_DIRECT_ADDED,
		model.WEBSOCKET_EVENT_GROUP_ADDED,
		model.WEBSOCKET_EVENT_ADDED_TO_TEAM:
		cm.SendType = model.CLUSTER_SEND_RELIABLE
	}
	s.Cluster.SendClusterMessage(cm)
}
// Publish delivers the event, delegating to the server implementation.
func (a *App) Publish(message *model.WebSocketEvent) {
	a.Srv().Publish(message)
}
// PublishSkipClusterSend delivers the event to the local hubs only, without
// relaying it to other cluster nodes.
func (s *Server) PublishSkipClusterSend(message *model.WebSocketEvent) {
	if userId := message.GetBroadcast().UserId; userId != "" {
		// Targeted event: only the hub owning this user needs to see it.
		if hub := s.GetHubForUserId(userId); hub != nil {
			hub.Broadcast(message)
		}
		return
	}
	// Untargeted event: fan out to every hub.
	for _, hub := range s.hubs {
		hub.Broadcast(message)
	}
}
// PublishSkipClusterSend delivers the event to the local hubs only,
// delegating to the server implementation.
func (a *App) PublishSkipClusterSend(message *model.WebSocketEvent) {
	a.Srv().PublishSkipClusterSend(message)
}
// invalidateCacheForChannel invalidates the local channel caches and, when
// clustering is enabled, tells the other nodes to invalidate their by-name
// channel caches too.
func (a *App) invalidateCacheForChannel(channel *model.Channel) {
	a.Srv().Store.Channel().InvalidateChannel(channel.Id)
	a.invalidateCacheForChannelByNameSkipClusterSend(channel.TeamId, channel.Name)

	if a.Cluster() == nil {
		return
	}
	// Direct-message channels have no team; "dm" is the sentinel id used
	// for them throughout the cluster protocol.
	teamId := channel.TeamId
	if teamId == "" {
		teamId = "dm"
	}
	nameMsg := &model.ClusterMessage{
		Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_CHANNEL_BY_NAME,
		SendType: model.CLUSTER_SEND_BEST_EFFORT,
		Props: map[string]string{
			"name": channel.Name,
			"id":   teamId,
		},
	}
	a.Cluster().SendClusterMessage(nameMsg)
}
// invalidateCacheForChannelMembers invalidates the local profile, member
// count, and guest count caches for the channel.
func (a *App) invalidateCacheForChannelMembers(channelId string) {
	a.Srv().Store.User().InvalidateProfilesInChannelCache(channelId)
	a.Srv().Store.Channel().InvalidateMemberCount(channelId)
	a.Srv().Store.Channel().InvalidateGuestCount(channelId)
}
// invalidateCacheForChannelMembersNotifyProps invalidates the notify-props
// cache locally and, when clustering is enabled, on the other nodes too.
func (a *App) invalidateCacheForChannelMembersNotifyProps(channelId string) {
	a.invalidateCacheForChannelMembersNotifyPropsSkipClusterSend(channelId)

	cluster := a.Cluster()
	if cluster == nil {
		return
	}
	cluster.SendClusterMessage(&model.ClusterMessage{
		Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_CHANNEL_MEMBERS_NOTIFY_PROPS,
		SendType: model.CLUSTER_SEND_BEST_EFFORT,
		Data:     channelId,
	})
}
// invalidateCacheForChannelMembersNotifyPropsSkipClusterSend invalidates the
// local notify-props cache for the channel without notifying the cluster.
func (a *App) invalidateCacheForChannelMembersNotifyPropsSkipClusterSend(channelId string) {
	a.Srv().Store.Channel().InvalidateCacheForChannelMembersNotifyProps(channelId)
}
// invalidateCacheForChannelByNameSkipClusterSend invalidates the local
// by-name channel cache entry without notifying the cluster. An empty
// teamId denotes a direct-message channel and is normalized to "dm".
func (a *App) invalidateCacheForChannelByNameSkipClusterSend(teamId, name string) {
	if teamId == "" {
		teamId = "dm"
	}
	a.Srv().Store.Channel().InvalidateChannelByName(teamId, name)
}
// invalidateCacheForChannelPosts invalidates the local pinned-post count and
// last-post time caches for the channel.
func (a *App) invalidateCacheForChannelPosts(channelId string) {
	a.Srv().Store.Channel().InvalidatePinnedPostCount(channelId)
	a.Srv().Store.Post().InvalidateLastPostTimeCache(channelId)
}
// InvalidateCacheForUser invalidates the user's profile and session caches
// locally and, when clustering is enabled, on the other nodes too.
func (a *App) InvalidateCacheForUser(userId string) {
	a.invalidateCacheForUserSkipClusterSend(userId)
	a.Srv().Store.User().InvalidateProfilesInChannelCacheByUser(userId)
	a.Srv().Store.User().InvalidateProfileCacheForUser(userId)

	if cluster := a.Cluster(); cluster != nil {
		cluster.SendClusterMessage(&model.ClusterMessage{
			Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_USER,
			SendType: model.CLUSTER_SEND_BEST_EFFORT,
			Data:     userId,
		})
	}
}
// invalidateCacheForUserTeams invalidates the user's webconn session cache
// and team-id cache locally and, when clustering is enabled, on the other
// nodes too.
func (a *App) invalidateCacheForUserTeams(userId string) {
	a.InvalidateWebConnSessionCacheForUser(userId)
	a.Srv().Store.Team().InvalidateAllTeamIdsForUser(userId)
	if a.Cluster() != nil {
		msg := &model.ClusterMessage{
			Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_USER_TEAMS,
			SendType: model.CLUSTER_SEND_BEST_EFFORT,
			Data:     userId,
		}
		a.Cluster().SendClusterMessage(msg)
	}
}
// UpdateWebConnUserActivity sets the LastUserActivityAt of the hub for the given session.
func (a *App) UpdateWebConnUserActivity(session model.Session, activityAt int64) {
	if hub := a.GetHubForUserId(session.UserId); hub != nil {
		hub.UpdateActivity(session.UserId, session.Token, activityAt)
	}
}
// SessionIsRegistered determines if a specific session has been registered.
func (a *App) SessionIsRegistered(session model.Session) bool {
	hub := a.GetHubForUserId(session.UserId)
	if hub == nil {
		return false
	}
	return hub.IsRegistered(session.UserId, session.Token)
}
// Register registers a connection to the hub. The select on h.stop prevents
// the caller from blocking forever if the hub has already been stopped.
func (h *Hub) Register(webConn *WebConn) {
	select {
	case h.register <- webConn:
	case <-h.stop:
	}
}
// Unregister unregisters a connection from the hub. The select on h.stop
// prevents the caller from blocking forever if the hub has been stopped.
func (h *Hub) Unregister(webConn *WebConn) {
	select {
	case h.unregister <- webConn:
	case <-h.stop:
	}
}
// IsRegistered determines if a connection for the user's session is
// registered with the hub. It returns false if the hub stops before the
// query can be answered.
func (h *Hub) IsRegistered(userId, sessionToken string) bool {
	ws := &webConnSessionMessage{
		userId:       userId,
		sessionToken: sessionToken,
		isRegistered: make(chan bool),
	}
	select {
	case h.checkRegistered <- ws:
		// The hub loop replies exactly once per received query.
		return <-ws.isRegistered
	case <-h.stop:
	}
	return false
}
// Broadcast broadcasts the message to all connections in the hub.
//
// XXX: The nil check on h exists because of the way our tests are set up:
// app.NewServer() happily broadcasts some messages before wsapi.Init() has
// initialized the hubs. This needs to be fixed once the wsapi cyclic
// dependency with server/app goes away — possibly by doing the hub
// initialization inside NewServer itself.
func (h *Hub) Broadcast(message *model.WebSocketEvent) {
	if h == nil || message == nil {
		return
	}
	if metrics := h.app.Metrics(); metrics != nil {
		metrics.IncrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
	}
	select {
	case h.broadcast <- message:
	case <-h.stop:
	}
}
// InvalidateUser invalidates the cache for the given user. The select on
// h.stop prevents blocking forever if the hub has been stopped.
func (h *Hub) InvalidateUser(userId string) {
	select {
	case h.invalidateUser <- userId:
	case <-h.stop:
	}
}
// UpdateActivity sets the LastUserActivityAt field for the connection
// of the user.
func (h *Hub) UpdateActivity(userId, sessionToken string, activityAt int64) {
	msg := &webConnActivityMessage{
		userId:       userId,
		sessionToken: sessionToken,
		activityAt:   activityAt,
	}
	select {
	case h.activity <- msg:
	case <-h.stop:
	}
}
// SendMessage sends the given message to the given connection.
func (h *Hub) SendMessage(conn *WebConn, msg model.WebSocketMessage) {
	dm := &webConnDirectMessage{
		conn: conn,
		msg:  msg,
	}
	select {
	case h.directMsg <- dm:
	case <-h.stop:
	}
}
// Stop stops the hub and blocks until the hub loop has fully exited.
// Note: Stop is not idempotent — a second call closes h.stop again and
// panics — so each hub must be stopped exactly once, at server shutdown.
func (h *Hub) Stop() {
	close(h.stop)
	<-h.didStop
}
// Start starts the hub. The hub's event loop runs in its own goroutine and
// is restarted automatically if it panics or returns without an explicit
// stop having been requested.
func (h *Hub) Start() {
	var doStart func()
	var doRecoverableStart func()
	var doRecover func()

	// doStart is the hub event loop. It owns connIndex and is the only
	// goroutine that touches it, so no locking is needed.
	doStart = func() {
		mlog.Debug("Hub is starting", mlog.Int("index", h.connectionIndex))
		connIndex := newHubConnectionIndex()
		for {
			select {
			case webSessionMessage := <-h.checkRegistered:
				// Answer a session-registration query from IsRegistered.
				conns := connIndex.ForUser(webSessionMessage.userId)
				var isRegistered bool
				for _, conn := range conns {
					if conn.GetSessionToken() == webSessionMessage.sessionToken {
						isRegistered = true
					}
				}
				webSessionMessage.isRegistered <- isRegistered
			case webConn := <-h.register:
				connIndex.Add(webConn)
				atomic.StoreInt64(&h.connectionCount, int64(len(connIndex.All())))
				if webConn.IsAuthenticated() {
					webConn.send <- webConn.createHelloMessage()
				}
			case webConn := <-h.unregister:
				connIndex.Remove(webConn)
				atomic.StoreInt64(&h.connectionCount, int64(len(connIndex.All())))
				if len(webConn.UserId) == 0 {
					continue
				}
				conns := connIndex.ForUser(webConn.UserId)
				if len(conns) == 0 {
					// The user's last connection is gone: mark them offline.
					h.app.Srv().Go(func() {
						h.app.SetStatusOffline(webConn.UserId, false)
					})
					continue
				}
				// Other connections remain; find the most recent activity so
				// the user can be transitioned to away if appropriate.
				var latestActivity int64 = 0
				for _, conn := range conns {
					if conn.lastUserActivityAt > latestActivity {
						latestActivity = conn.lastUserActivityAt
					}
				}
				if h.app.IsUserAway(latestActivity) {
					h.app.Srv().Go(func() {
						h.app.SetStatusLastActivityAt(webConn.UserId, latestActivity)
					})
				}
			case userId := <-h.invalidateUser:
				for _, webConn := range connIndex.ForUser(userId) {
					webConn.InvalidateCache()
				}
			case activity := <-h.activity:
				for _, webConn := range connIndex.ForUser(activity.userId) {
					if webConn.GetSessionToken() == activity.sessionToken {
						webConn.lastUserActivityAt = activity.activityAt
					}
				}
			case directMsg := <-h.directMsg:
				if !connIndex.Has(directMsg.conn) {
					continue
				}
				select {
				case directMsg.conn.send <- directMsg.msg:
				default:
					// The connection's send buffer is full; drop the
					// connection rather than block the hub loop.
					mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", directMsg.conn.UserId))
					close(directMsg.conn.send)
					connIndex.Remove(directMsg.conn)
				}
			case msg := <-h.broadcast:
				if metrics := h.app.Metrics(); metrics != nil {
					metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
				}
				// Serialize once so every connection reuses the same payload.
				msg = msg.PrecomputeJSON()
				broadcast := func(webConn *WebConn) {
					if !connIndex.Has(webConn) {
						return
					}
					if webConn.shouldSendEvent(msg) {
						select {
						case webConn.send <- msg:
						default:
							// Slow consumer: drop the connection rather than
							// stall every other connection on this hub.
							mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", webConn.UserId))
							close(webConn.send)
							connIndex.Remove(webConn)
						}
					}
				}
				if msg.GetBroadcast().UserId != "" {
					// Targeted event: only this user's connections.
					candidates := connIndex.ForUser(msg.GetBroadcast().UserId)
					for _, webConn := range candidates {
						broadcast(webConn)
					}
					continue
				}
				// Untargeted event: every connection on this hub.
				candidates := connIndex.All()
				for webConn := range candidates {
					broadcast(webConn)
				}
			case <-h.stop:
				// Orderly shutdown: close every connection and mark its user
				// offline before signalling completion via didStop.
				for webConn := range connIndex.All() {
					webConn.Close()
					h.app.SetStatusOffline(webConn.UserId, false)
				}
				h.explicitStop = true
				close(h.didStop)
				return
			}
		}
	}

	// doRecoverableStart runs the loop under the panic-recovery wrapper.
	doRecoverableStart = func() {
		defer doRecover()
		doStart()
	}

	// doRecover restarts the loop unless the hub was stopped explicitly.
	doRecover = func() {
		if !h.explicitStop {
			if r := recover(); r != nil {
				mlog.Error("Recovering from Hub panic.", mlog.Any("panic", r))
			} else {
				mlog.Error("Webhub stopped unexpectedly. Recovering.")
			}
			mlog.Error(string(debug.Stack()))
			go doRecoverableStart()
		}
	}

	go doRecoverableStart()
}
// hubConnectionIndex provides fast addition, removal, and iteration of web connections.
// It requires 3 functionalities which need to be very fast:
// - check if a connection exists or not.
// - get all connections for a given userID.
// - get all connections.
// It is not safe for concurrent use; the hub loop is its sole owner.
type hubConnectionIndex struct {
	// byUserId stores the list of connections for a given userID
	byUserId map[string][]*WebConn
	// byConnection serves the dual purpose of storing the index of the webconn
	// in the value of byUserId map, and also to get all connections.
	byConnection map[*WebConn]int
}
// newHubConnectionIndex returns an empty, ready-to-use connection index.
func newHubConnectionIndex() *hubConnectionIndex {
	index := new(hubConnectionIndex)
	index.byUserId = make(map[string][]*WebConn)
	index.byConnection = make(map[*WebConn]int)
	return index
}
// Add inserts wc into the index, recording its slot in the per-user slice
// so it can later be removed in O(1).
func (i *hubConnectionIndex) Add(wc *WebConn) {
	conns := append(i.byUserId[wc.UserId], wc)
	i.byUserId[wc.UserId] = conns
	i.byConnection[wc] = len(conns) - 1
}
// Remove deletes wc from the index in O(1) using swap-with-last removal on
// the per-user slice. It is a no-op if wc is not present.
func (i *hubConnectionIndex) Remove(wc *WebConn) {
	userConnIndex, ok := i.byConnection[wc]
	if !ok {
		return
	}
	// get the conn slice.
	userConnections := i.byUserId[wc.UserId]
	// get the last connection.
	last := userConnections[len(userConnections)-1]
	// set the slot that we are trying to remove to be the last connection.
	userConnections[userConnIndex] = last
	// remove the last connection from the slice.
	i.byUserId[wc.UserId] = userConnections[:len(userConnections)-1]
	// set the index of the connection that was moved to the new index.
	i.byConnection[last] = userConnIndex
	delete(i.byConnection, wc)
}
// Has reports whether wc is present in the index.
func (i *hubConnectionIndex) Has(wc *WebConn) bool {
	_, ok := i.byConnection[wc]
	return ok
}
// ForUser returns all connections for the given user id (nil if none).
func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
	return i.byUserId[id]
}
// All returns the underlying connection set; callers must treat it as
// read-only and must not retain it across hub loop iterations.
func (i *hubConnectionIndex) All() map[*WebConn]int {
	return i.byConnection
}