mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
[MM-57010] Include client type in websocket connections metric (#26763)
* Include client type in websocket connections metric * Unexport field
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mattermost/mattermost/server/public/model"
|
||||
"github.com/mattermost/mattermost/server/public/shared/mlog"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/app/platform"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/web"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,13 @@ func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
Locale: "",
|
||||
Active: true,
|
||||
}
|
||||
// The WebSocket upgrade request coming from mobile is missing the
|
||||
// user agent so we need to fallback on the session's metadata.
|
||||
if c.AppContext.Session().IsMobileApp() {
|
||||
cfg.OriginClient = "mobile"
|
||||
} else {
|
||||
cfg.OriginClient = string(web.GetOriginClient(r))
|
||||
}
|
||||
|
||||
cfg.ConnectionID = r.URL.Query().Get(connectionIDParam)
|
||||
if cfg.ConnectionID == "" || c.AppContext.Session().UserId == "" {
|
||||
|
||||
@@ -67,6 +67,7 @@ type WebConnConfig struct {
|
||||
ConnectionID string
|
||||
Active bool
|
||||
ReuseCount int
|
||||
OriginClient string
|
||||
|
||||
// These aren't necessary to be exported to api layer.
|
||||
sequence int
|
||||
@@ -115,6 +116,9 @@ type WebConn struct {
|
||||
session atomic.Pointer[model.Session]
|
||||
connectionID atomic.Value
|
||||
|
||||
// The client type behind the connection (i.e. web, desktop or mobile)
|
||||
originClient string
|
||||
|
||||
activeChannelID atomic.Value
|
||||
activeTeamID atomic.Value
|
||||
activeRHSThreadChannelID atomic.Value
|
||||
@@ -236,6 +240,7 @@ func (ps *PlatformService) NewWebConn(cfg *WebConnConfig, suite SuiteIFace, runn
|
||||
pluginPosted: make(chan pluginWSPostedHook, 10),
|
||||
lastLogTimeSlow: time.Now(),
|
||||
lastLogTimeFull: time.Now(),
|
||||
originClient: cfg.OriginClient,
|
||||
}
|
||||
wc.active.Store(cfg.Active)
|
||||
|
||||
@@ -954,11 +959,13 @@ func (wc *WebConn) logSocketErr(source string, err error) {
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
|
||||
mlog.Debug(source+": client side closed socket",
|
||||
mlog.String("user_id", wc.UserId),
|
||||
mlog.String("conn_id", wc.GetConnectionID()))
|
||||
mlog.String("conn_id", wc.GetConnectionID()),
|
||||
mlog.String("origin_client", wc.originClient))
|
||||
} else {
|
||||
mlog.Debug(source+": closing websocket",
|
||||
mlog.String("user_id", wc.UserId),
|
||||
mlog.String("conn_id", wc.GetConnectionID()),
|
||||
mlog.String("origin_client", wc.originClient),
|
||||
mlog.Err(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,6 +391,9 @@ func (h *Hub) Start() {
|
||||
|
||||
connIndex.Add(webConn)
|
||||
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
|
||||
if metrics := h.platform.metricsIFace; metrics != nil {
|
||||
metrics.IncrementHTTPWebSockets(webConn.originClient)
|
||||
}
|
||||
|
||||
if webConn.IsAuthenticated() && webConn.reuseCount == 0 {
|
||||
// The hello message should only be sent when the reuseCount is 0.
|
||||
@@ -405,6 +408,9 @@ func (h *Hub) Start() {
|
||||
webConn.active.Store(false)
|
||||
|
||||
atomic.StoreInt64(&h.connectionCount, int64(connIndex.AllActive()))
|
||||
if metrics := h.platform.metricsIFace; metrics != nil {
|
||||
metrics.DecrementHTTPWebSockets(webConn.originClient)
|
||||
}
|
||||
|
||||
if webConn.UserId == "" {
|
||||
continue
|
||||
|
||||
@@ -430,9 +430,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
pageLoadContext = ""
|
||||
}
|
||||
|
||||
originClient := string(originClient(r))
|
||||
|
||||
c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, originClient, pageLoadContext, elapsed)
|
||||
c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, string(GetOriginClient(r)), pageLoadContext, elapsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -446,12 +444,12 @@ const (
|
||||
OriginClientDesktop OriginClient = "desktop"
|
||||
)
|
||||
|
||||
// originClient returns the device from which the provided request was issued. The algorithm roughly looks like:
|
||||
// GetOriginClient returns the device from which the provided request was issued. The algorithm roughly looks like:
|
||||
// - If the URL contains the query mobilev2=true, then it's mobile
|
||||
// - If the first field of the user agent starts with either "rnbeta" or "Mattermost", then it's mobile
|
||||
// - If the last field of the user agent starts with "Mattermost", then it's desktop
|
||||
// - Otherwise, it's web
|
||||
func originClient(r *http.Request) OriginClient {
|
||||
func GetOriginClient(r *http.Request) OriginClient {
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
fields := strings.Fields(userAgent)
|
||||
if len(fields) < 1 {
|
||||
|
||||
@@ -883,7 +883,7 @@ func TestCheckCSRFToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOriginClient(t *testing.T) {
|
||||
func TestGetOriginClient(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
@@ -945,7 +945,7 @@ func TestOriginClient(t *testing.T) {
|
||||
}
|
||||
|
||||
// Compute origin client
|
||||
actualClient := originClient(req)
|
||||
actualClient := GetOriginClient(req)
|
||||
|
||||
require.Equal(t, tc.expectedClient, actualClient)
|
||||
}
|
||||
|
||||
@@ -50,6 +50,9 @@ type MetricsInterface interface {
|
||||
DecrementWebSocketBroadcastUsersRegistered(hub string, amount float64)
|
||||
IncrementWebsocketReconnectEvent(eventType string)
|
||||
|
||||
IncrementHTTPWebSockets(originClient string)
|
||||
DecrementHTTPWebSockets(originClient string)
|
||||
|
||||
AddMemCacheHitCounter(cacheName string, amount float64)
|
||||
AddMemCacheMissCounter(cacheName string, amount float64)
|
||||
|
||||
|
||||
@@ -28,6 +28,11 @@ func (_m *MetricsInterface) AddMemCacheMissCounter(cacheName string, amount floa
|
||||
_m.Called(cacheName, amount)
|
||||
}
|
||||
|
||||
// DecrementHTTPWebSockets provides a mock function with given fields: originClient
|
||||
func (_m *MetricsInterface) DecrementHTTPWebSockets(originClient string) {
|
||||
_m.Called(originClient)
|
||||
}
|
||||
|
||||
// DecrementJobActive provides a mock function with given fields: jobType
|
||||
func (_m *MetricsInterface) DecrementJobActive(jobType string) {
|
||||
_m.Called(jobType)
|
||||
@@ -108,6 +113,11 @@ func (_m *MetricsInterface) IncrementHTTPRequest() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// IncrementHTTPWebSockets provides a mock function with given fields: originClient
|
||||
func (_m *MetricsInterface) IncrementHTTPWebSockets(originClient string) {
|
||||
_m.Called(originClient)
|
||||
}
|
||||
|
||||
// IncrementJobActive provides a mock function with given fields: jobType
|
||||
func (_m *MetricsInterface) IncrementJobActive(jobType string) {
|
||||
_m.Called(jobType)
|
||||
|
||||
@@ -66,7 +66,7 @@ type MetricsInterfaceImpl struct {
|
||||
|
||||
HTTPRequestsCounter prometheus.Counter
|
||||
HTTPErrorsCounter prometheus.Counter
|
||||
HTTPWebsocketsGauge prometheus.GaugeFunc
|
||||
HTTPWebsocketsGauge *prometheus.GaugeVec
|
||||
|
||||
ClusterRequestsDuration prometheus.Histogram
|
||||
ClusterRequestsCounter prometheus.Counter
|
||||
@@ -368,13 +368,13 @@ func New(ps *platform.PlatformService, driver, dataSource string) *MetricsInterf
|
||||
|
||||
// HTTP Subsystem
|
||||
|
||||
m.HTTPWebsocketsGauge = prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
||||
m.HTTPWebsocketsGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: MetricsSubsystemHTTP,
|
||||
Name: "websockets_total",
|
||||
Help: "The total number of websocket connections to this server.",
|
||||
ConstLabels: additionalLabels,
|
||||
}, func() float64 { return float64(m.Platform.TotalWebsocketConnections()) })
|
||||
}, []string{"origin_client"})
|
||||
m.Registry.MustRegister(m.HTTPWebsocketsGauge)
|
||||
|
||||
m.HTTPRequestsCounter = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
@@ -1467,6 +1467,14 @@ func (mi *MetricsInterfaceImpl) SetReplicaLagTime(node string, value float64) {
|
||||
mi.DbReplicaLagGaugeTime.With(prometheus.Labels{"node": node}).Set(value)
|
||||
}
|
||||
|
||||
func (mi *MetricsInterfaceImpl) IncrementHTTPWebSockets(originClient string) {
|
||||
mi.HTTPWebsocketsGauge.With(prometheus.Labels{"origin_client": originClient}).Inc()
|
||||
}
|
||||
|
||||
func (mi *MetricsInterfaceImpl) DecrementHTTPWebSockets(originClient string) {
|
||||
mi.HTTPWebsocketsGauge.With(prometheus.Labels{"origin_client": originClient}).Dec()
|
||||
}
|
||||
|
||||
func extractDBCluster(driver, connectionString string) (string, error) {
|
||||
host, err := extractHost(driver, connectionString)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user