diff --git a/server/channels/api4/websocket.go b/server/channels/api4/websocket.go index 53fc767a67..001bc6a052 100644 --- a/server/channels/api4/websocket.go +++ b/server/channels/api4/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/v8/channels/app/platform" + "github.com/mattermost/mattermost/server/v8/channels/web" ) const ( @@ -48,6 +49,13 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) { Locale: "", Active: true, } + // The WebSocket upgrade request coming from mobile is missing the + // user agent so we need to fallback on the session's metadata. + if c.AppContext.Session().IsMobileApp() { + cfg.OriginClient = "mobile" + } else { + cfg.OriginClient = string(web.GetOriginClient(r)) + } cfg.ConnectionID = r.URL.Query().Get(connectionIDParam) if cfg.ConnectionID == "" || c.AppContext.Session().UserId == "" { diff --git a/server/channels/app/platform/web_conn.go b/server/channels/app/platform/web_conn.go index 99f5ca82b3..d44a9b3b23 100644 --- a/server/channels/app/platform/web_conn.go +++ b/server/channels/app/platform/web_conn.go @@ -67,6 +67,7 @@ type WebConnConfig struct { ConnectionID string Active bool ReuseCount int + OriginClient string // These aren't necessary to be exported to api layer. sequence int @@ -115,6 +116,9 @@ type WebConn struct { session atomic.Pointer[model.Session] connectionID atomic.Value + // The client type behind the connection (i.e. web, desktop or mobile) + originClient string + activeChannelID atomic.Value activeTeamID atomic.Value activeRHSThreadChannelID atomic.Value @@ -236,6 +240,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn pluginPosted: make(chan pluginWSPostedHook, 10), lastLogTimeSlow: time.Now(), lastLogTimeFull: time.Now(), + originClient: cfg.OriginClient, } wc.active.Store(cfg.Active) @@ -954,11 +959,13 @@ func (wc *WebConn) logSocketErr(source string, err error) { if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { mlog.Debug(source+": client side closed socket", mlog.String("user_id", wc.UserId), - mlog.String("conn_id", wc.GetConnectionID())) + mlog.String("conn_id", wc.GetConnectionID()), + mlog.String("origin_client", wc.originClient)) } else { mlog.Debug(source+": closing websocket", mlog.String("user_id", wc.UserId), mlog.String("conn_id", wc.GetConnectionID()), + mlog.String("origin_client", wc.originClient), mlog.Err(err)) } } diff --git a/server/channels/app/platform/web_hub.go b/server/channels/app/platform/web_hub.go index 94d5d703ee..9f62534685 100644 --- a/server/channels/app/platform/web_hub.go +++ b/server/channels/app/platform/web_hub.go @@ -391,6 +391,9 @@ func (h *Hub) Start() { connIndex.Add(webConn) atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive())) + if metrics := h.platform.metricsIFace; metrics != nil { + metrics.IncrementHTTPWebSockets(webConn.originClient) + } if webConn.IsAuthenticated() && webConn.reuseCount == 0 { // The hello message should only be sent when the reuseCount is 0. @@ -405,6 +408,9 @@ func (h *Hub) Start() { webConn.active.Store(false) atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive())) + if metrics := h.platform.metricsIFace; metrics != nil { + metrics.DecrementHTTPWebSockets(webConn.originClient) + } if webConn.UserId == "" { continue diff --git a/server/channels/web/handlers.go b/server/channels/web/handlers.go index f2caca1f83..97d5ad5165 100644 --- a/server/channels/web/handlers.go +++ b/server/channels/web/handlers.go @@ -430,9 +430,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { pageLoadContext = "" } - originClient := string(originClient(r)) - - c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, originClient, pageLoadContext, elapsed) + c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, string(GetOriginClient(r)), pageLoadContext, elapsed) } } } @@ -446,12 +444,12 @@ const ( OriginClientDesktop OriginClient = "desktop" ) -// originClient returns the device from which the provided request was issued. The algorithm roughly looks like: +// GetOriginClient returns the device from which the provided request was issued. The algorithm roughly looks like: // - If the URL contains the query mobilev2=true, then it's mobile // - If the first field of the user agent starts with either "rnbeta" or "Mattermost", then it's mobile // - If the last field of the user agent starts with "Mattermost", then it's desktop // - Otherwise, it's web -func originClient(r *http.Request) OriginClient { +func GetOriginClient(r *http.Request) OriginClient { userAgent := r.Header.Get("User-Agent") fields := strings.Fields(userAgent) if len(fields) < 1 { diff --git a/server/channels/web/handlers_test.go b/server/channels/web/handlers_test.go index 059c612088..12caec6b30 100644 --- a/server/channels/web/handlers_test.go +++ b/server/channels/web/handlers_test.go @@ -883,7 +883,7 @@ func TestCheckCSRFToken(t *testing.T) { }) } -func TestOriginClient(t *testing.T) { +func TestGetOriginClient(t *testing.T) { testCases := []struct { name string userAgent string @@ -945,7 +945,7 @@ func TestOriginClient(t *testing.T) { } // Compute origin client - actualClient := originClient(req) + actualClient := GetOriginClient(req) require.Equal(t, tc.expectedClient, actualClient) } diff --git a/server/einterfaces/metrics.go b/server/einterfaces/metrics.go index 1e2cf9aa49..dac9ca016d 100644 --- a/server/einterfaces/metrics.go +++ b/server/einterfaces/metrics.go @@ -50,6 +50,9 @@ type MetricsInterface interface { DecrementWebSocketBroadcastUsersRegistered(hub string, amount float64) IncrementWebsocketReconnectEvent(eventType string) + IncrementHTTPWebSockets(originClient string) + DecrementHTTPWebSockets(originClient string) + AddMemCacheHitCounter(cacheName string, amount float64) AddMemCacheMissCounter(cacheName string, amount float64) diff --git a/server/einterfaces/mocks/MetricsInterface.go b/server/einterfaces/mocks/MetricsInterface.go index 20120c7cb9..27ecc9ca0f 100644 --- a/server/einterfaces/mocks/MetricsInterface.go +++ b/server/einterfaces/mocks/MetricsInterface.go @@ -28,6 +28,11 @@ func (_m *MetricsInterface) AddMemCacheMissCounter(cacheName string, amount floa _m.Called(cacheName, amount) } +// DecrementHTTPWebSockets provides a mock function with given fields: originClient +func (_m *MetricsInterface) DecrementHTTPWebSockets(originClient string) { + _m.Called(originClient) +} + // DecrementJobActive provides a mock function with given fields: jobType func (_m *MetricsInterface) DecrementJobActive(jobType string) { _m.Called(jobType) @@ -108,6 +113,11 @@ func (_m *MetricsInterface) IncrementHTTPRequest() { _m.Called() } +// IncrementHTTPWebSockets provides a mock function with given fields: originClient +func (_m *MetricsInterface) IncrementHTTPWebSockets(originClient string) { + _m.Called(originClient) +} + // IncrementJobActive provides a mock function with given fields: jobType func (_m *MetricsInterface) IncrementJobActive(jobType string) { _m.Called(jobType) diff --git a/server/enterprise/metrics/metrics.go b/server/enterprise/metrics/metrics.go index b1d7331c8a..06caeac358 100644 --- a/server/enterprise/metrics/metrics.go +++ b/server/enterprise/metrics/metrics.go @@ -66,7 +66,7 @@ type MetricsInterfaceImpl struct { HTTPRequestsCounter prometheus.Counter HTTPErrorsCounter prometheus.Counter - HTTPWebsocketsGauge prometheus.GaugeFunc + HTTPWebsocketsGauge *prometheus.GaugeVec ClusterRequestsDuration prometheus.Histogram ClusterRequestsCounter prometheus.Counter @@ -368,13 +368,13 @@ func New(ps *platform.PlatformService, driver, dataSource string) *MetricsInterf // HTTP Subsystem - m.HTTPWebsocketsGauge = prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + m.HTTPWebsocketsGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: MetricsNamespace, Subsystem: MetricsSubsystemHTTP, Name: "websockets_total", Help: "The total number of websocket connections to this server.", ConstLabels: additionalLabels, - }, func() float64 { return float64(m.Platform.TotalWebsocketConnections()) }) + }, []string{"origin_client"}) m.Registry.MustRegister(m.HTTPWebsocketsGauge) m.HTTPRequestsCounter = prometheus.NewCounter(prometheus.CounterOpts{ @@ -1467,6 +1467,14 @@ func (mi *MetricsInterfaceImpl) SetReplicaLagTime(node string, value float64) { mi.DbReplicaLagGaugeTime.With(prometheus.Labels{"node": node}).Set(value) } +func (mi *MetricsInterfaceImpl) IncrementHTTPWebSockets(originClient string) { + mi.HTTPWebsocketsGauge.With(prometheus.Labels{"origin_client": originClient}).Inc() +} + +func (mi *MetricsInterfaceImpl) DecrementHTTPWebSockets(originClient string) { + mi.HTTPWebsocketsGauge.With(prometheus.Labels{"origin_client": originClient}).Dec() +} + func extractDBCluster(driver, connectionString string) (string, error) { host, err := extractHost(driver, connectionString) if err != nil {