[MM-57010] Include client type in websocket connections metric (#26763)

* Include client type in websocket connections metric

* Unexport field
This commit is contained in:
Claudio Costa
2024-04-16 10:49:49 -06:00
committed by GitHub
parent effb99301e
commit 4b508eed46
8 changed files with 51 additions and 11 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
"github.com/mattermost/mattermost/server/v8/channels/web"
)
const (
@@ -48,6 +49,13 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
Locale: "",
Active: true,
}
// The WebSocket upgrade request coming from mobile is missing the
// user agent so we need to fallback on the session's metadata.
if c.AppContext.Session().IsMobileApp() {
cfg.OriginClient = "mobile"
} else {
cfg.OriginClient = string(web.GetOriginClient(r))
}
cfg.ConnectionID = r.URL.Query().Get(connectionIDParam)
if cfg.ConnectionID == "" || c.AppContext.Session().UserId == "" {

View File

@@ -67,6 +67,7 @@ type WebConnConfig struct {
ConnectionID string
Active bool
ReuseCount int
OriginClient string
// These aren't necessary to be exported to api layer.
sequence int
@@ -115,6 +116,9 @@ type WebConn struct {
session atomic.Pointer[model.Session]
connectionID atomic.Value
// The client type behind the connection (i.e. web, desktop or mobile)
originClient string
activeChannelID atomic.Value
activeTeamID atomic.Value
activeRHSThreadChannelID atomic.Value
@@ -236,6 +240,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn
pluginPosted: make(chan pluginWSPostedHook, 10),
lastLogTimeSlow: time.Now(),
lastLogTimeFull: time.Now(),
originClient: cfg.OriginClient,
}
wc.active.Store(cfg.Active)
@@ -954,11 +959,13 @@ func (wc *WebConn) logSocketErr(source string, err error) {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
mlog.Debug(source+": client side closed socket",
mlog.String("user_id", wc.UserId),
mlog.String("conn_id", wc.GetConnectionID()))
mlog.String("conn_id", wc.GetConnectionID()),
mlog.String("origin_client", wc.originClient))
} else {
mlog.Debug(source+": closing websocket",
mlog.String("user_id", wc.UserId),
mlog.String("conn_id", wc.GetConnectionID()),
mlog.String("origin_client", wc.originClient),
mlog.Err(err))
}
}

View File

@@ -391,6 +391,9 @@ func (h *Hub) Start() {
connIndex.Add(webConn)
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
if metrics := h.platform.metricsIFace; metrics != nil {
metrics.IncrementHTTPWebSockets(webConn.originClient)
}
if webConn.IsAuthenticated() && webConn.reuseCount == 0 {
// The hello message should only be sent when the reuseCount is 0.
@@ -405,6 +408,9 @@ func (h *Hub) Start() {
webConn.active.Store(false)
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
if metrics := h.platform.metricsIFace; metrics != nil {
metrics.DecrementHTTPWebSockets(webConn.originClient)
}
if webConn.UserId == "" {
continue

View File

@@ -430,9 +430,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pageLoadContext = ""
}
originClient := string(originClient(r))
c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, originClient, pageLoadContext, elapsed)
c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, string(GetOriginClient(r)), pageLoadContext, elapsed)
}
}
}
@@ -446,12 +444,12 @@ const (
OriginClientDesktop OriginClient = "desktop"
)
// originClient returns the device from which the provided request was issued. The algorithm roughly looks like:
// GetOriginClient returns the device from which the provided request was issued. The algorithm roughly looks like:
// - If the URL contains the query mobilev2=true, then it's mobile
// - If the first field of the user agent starts with either "rnbeta" or "Mattermost", then it's mobile
// - If the last field of the user agent starts with "Mattermost", then it's desktop
// - Otherwise, it's web
func originClient(r *http.Request) OriginClient {
func GetOriginClient(r *http.Request) OriginClient {
userAgent := r.Header.Get("User-Agent")
fields := strings.Fields(userAgent)
if len(fields) < 1 {

View File

@@ -883,7 +883,7 @@ func TestCheckCSRFToken(t *testing.T) {
})
}
func TestOriginClient(t *testing.T) {
func TestGetOriginClient(t *testing.T) {
testCases := []struct {
name string
userAgent string
@@ -945,7 +945,7 @@ func TestOriginClient(t *testing.T) {
}
// Compute origin client
actualClient := originClient(req)
actualClient := GetOriginClient(req)
require.Equal(t, tc.expectedClient, actualClient)
}

View File

@@ -50,6 +50,9 @@ type MetricsInterface interface {
DecrementWebSocketBroadcastUsersRegistered(hub string, amount float64)
IncrementWebsocketReconnectEvent(eventType string)
IncrementHTTPWebSockets(originClient string)
DecrementHTTPWebSockets(originClient string)
AddMemCacheHitCounter(cacheName string, amount float64)
AddMemCacheMissCounter(cacheName string, amount float64)

View File

@@ -28,6 +28,11 @@ func (_m *MetricsInterface) AddMemCacheMissCounter(cacheName string, amount floa
_m.Called(cacheName, amount)
}
// DecrementHTTPWebSockets provides a mock function with given fields: originClient
func (_m *MetricsInterface) DecrementHTTPWebSockets(originClient string) {
_m.Called(originClient)
}
// DecrementJobActive provides a mock function with given fields: jobType
func (_m *MetricsInterface) DecrementJobActive(jobType string) {
_m.Called(jobType)
@@ -108,6 +113,11 @@ func (_m *MetricsInterface) IncrementHTTPRequest() {
_m.Called()
}
// IncrementHTTPWebSockets provides a mock function with given fields: originClient
func (_m *MetricsInterface) IncrementHTTPWebSockets(originClient string) {
_m.Called(originClient)
}
// IncrementJobActive provides a mock function with given fields: jobType
func (_m *MetricsInterface) IncrementJobActive(jobType string) {
_m.Called(jobType)

View File

@@ -66,7 +66,7 @@ type MetricsInterfaceImpl struct {
HTTPRequestsCounter prometheus.Counter
HTTPErrorsCounter prometheus.Counter
HTTPWebsocketsGauge prometheus.GaugeFunc
HTTPWebsocketsGauge *prometheus.GaugeVec
ClusterRequestsDuration prometheus.Histogram
ClusterRequestsCounter prometheus.Counter
@@ -368,13 +368,13 @@ func New(ps *platform.PlatformService, driver, dataSource string) *MetricsInterf
// HTTP Subsystem
m.HTTPWebsocketsGauge = prometheus.NewGaugeFunc(prometheus.GaugeOpts{
m.HTTPWebsocketsGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: MetricsSubsystemHTTP,
Name: "websockets_total",
Help: "The total number of websocket connections to this server.",
ConstLabels: additionalLabels,
}, func() float64 { return float64(m.Platform.TotalWebsocketConnections()) })
}, []string{"origin_client"})
m.Registry.MustRegister(m.HTTPWebsocketsGauge)
m.HTTPRequestsCounter = prometheus.NewCounter(prometheus.CounterOpts{
@@ -1467,6 +1467,14 @@ func (mi *MetricsInterfaceImpl) SetReplicaLagTime(node string, value float64) {
mi.DbReplicaLagGaugeTime.With(prometheus.Labels{"node": node}).Set(value)
}
func (mi *MetricsInterfaceImpl) IncrementHTTPWebSockets(originClient string) {
mi.HTTPWebsocketsGauge.With(prometheus.Labels{"origin_client": originClient}).Inc()
}
func (mi *MetricsInterfaceImpl) DecrementHTTPWebSockets(originClient string) {
mi.HTTPWebsocketsGauge.With(prometheus.Labels{"origin_client": originClient}).Dec()
}
func extractDBCluster(driver, connectionString string) (string, error) {
host, err := extractHost(driver, connectionString)
if err != nil {