UsageStatsService: Don't use global state (#31849)

* UsageStatsService: Don't use global state

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>
This commit is contained in:
Arve Knudsen 2021-03-10 10:14:00 +01:00 committed by GitHub
parent 598a44076a
commit 04e9f6c24f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 254 additions and 270 deletions

View File

@ -27,7 +27,6 @@ func init() {
type UsageStats interface {
GetUsageReport(ctx context.Context) (UsageReport, error)
RegisterMetric(name string, fn MetricFunc)
}

View File

@ -13,7 +13,6 @@ import (
"github.com/grafana/grafana/pkg/infra/metrics"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins/manager"
"github.com/grafana/grafana/pkg/setting"
)
var usageStatsURL = "https://stats.grafana.org/grafana-usage-report"
@ -29,18 +28,22 @@ type UsageReport struct {
}
func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport, error) {
version := strings.ReplaceAll(setting.BuildVersion, ".", "_")
version := strings.ReplaceAll(uss.Cfg.BuildVersion, ".", "_")
metrics := map[string]interface{}{}
edition := "oss"
if uss.Cfg.IsEnterprise {
edition = "enterprise"
}
report := UsageReport{
Version: version,
Metrics: metrics,
Os: runtime.GOOS,
Arch: runtime.GOARCH,
Edition: getEdition(),
Edition: edition,
HasValidLicense: uss.License.HasValidLicense(),
Packaging: setting.Packaging,
Packaging: uss.Cfg.Packaging,
}
statsQuery := models.GetSystemStatsQuery{}
@ -69,9 +72,19 @@ func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport,
metrics["stats.total_auth_token.count"] = statsQuery.Result.AuthTokens
metrics["stats.dashboard_versions.count"] = statsQuery.Result.DashboardVersions
metrics["stats.annotations.count"] = statsQuery.Result.Annotations
metrics["stats.valid_license.count"] = getValidLicenseCount(uss.License.HasValidLicense())
metrics["stats.edition.oss.count"] = getOssEditionCount()
metrics["stats.edition.enterprise.count"] = getEnterpriseEditionCount()
validLicCount := 0
if uss.License.HasValidLicense() {
validLicCount = 1
}
metrics["stats.valid_license.count"] = validLicCount
ossEditionCount := 1
enterpriseEditionCount := 0
if uss.Cfg.IsEnterprise {
enterpriseEditionCount = 1
ossEditionCount = 0
}
metrics["stats.edition.oss.count"] = ossEditionCount
metrics["stats.edition.enterprise.count"] = enterpriseEditionCount
uss.registerExternalMetrics(metrics)
@ -102,8 +115,8 @@ func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport,
}
metrics["stats.ds.other.count"] = dsOtherCount
metrics["stats.packaging."+setting.Packaging+".count"] = 1
metrics["stats.distributor."+setting.ReportingDistributor+".count"] = 1
metrics["stats.packaging."+uss.Cfg.Packaging+".count"] = 1
metrics["stats.distributor."+uss.Cfg.ReportingDistributor+".count"] = 1
// Alerting stats
alertingUsageStats, err := uss.AlertingUsageStats.QueryUsageStats()
@ -170,10 +183,10 @@ func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport,
// Add stats about auth configuration
authTypes := map[string]bool{}
authTypes["anonymous"] = setting.AnonymousEnabled
authTypes["basic_auth"] = setting.BasicAuthEnabled
authTypes["ldap"] = setting.LDAPEnabled
authTypes["auth_proxy"] = setting.AuthProxyEnabled
authTypes["anonymous"] = uss.Cfg.AnonymousEnabled
authTypes["basic_auth"] = uss.Cfg.BasicAuthEnabled
authTypes["ldap"] = uss.Cfg.LDAPEnabled
authTypes["auth_proxy"] = uss.Cfg.AuthProxyEnabled
for provider, enabled := range uss.oauthProviders {
authTypes["oauth_"+provider] = enabled
@ -221,7 +234,7 @@ func (uss *UsageStatsService) RegisterMetric(name string, fn MetricFunc) {
}
func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error {
if !setting.ReportingEnabled {
if !uss.Cfg.ReportingEnabled {
return nil
}
@ -237,9 +250,17 @@ func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error {
return err
}
data := bytes.NewBuffer(out)
sendUsageStats(data)
client := http.Client{Timeout: 5 * time.Second}
return nil
}
// sendUsageStats sends usage statistics.
//
// Stubbable by tests.
var sendUsageStats = func(data *bytes.Buffer) {
go func() {
client := http.Client{Timeout: 5 * time.Second}
resp, err := client.Post(usageStatsURL, "application/json", data)
if err != nil {
metricsLogger.Error("Failed to send usage stats", "err", err)
@ -249,8 +270,6 @@ func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error {
metricsLogger.Warn("Failed to close response body", "err", err)
}
}()
return nil
}
func (uss *UsageStatsService) updateTotalStats() {
@ -298,33 +317,3 @@ func (uss *UsageStatsService) shouldBeReported(dsType string) bool {
return ds.Signature.IsValid() || ds.Signature.IsInternal()
}
func getEdition() string {
edition := "oss"
if setting.IsEnterprise {
edition = "enterprise"
}
return edition
}
func getEnterpriseEditionCount() int {
if setting.IsEnterprise {
return 1
}
return 0
}
func getOssEditionCount() int {
if setting.IsEnterprise {
return 0
}
return 1
}
func getValidLicenseCount(validLicense bool) int {
if validLicense {
return 1
}
return 0
}

View File

@ -6,7 +6,6 @@ import (
"errors"
"io/ioutil"
"runtime"
"sync"
"testing"
"time"
@ -40,13 +39,8 @@ func Test_InterfaceContractValidity(t *testing.T) {
func TestMetrics(t *testing.T) {
t.Run("When sending usage stats", func(t *testing.T) {
setupSomeDataSourcePlugins(t)
uss := &UsageStatsService{
Bus: bus.New(),
SQLStore: sqlstore.InitTestDB(t),
License: &licensing.OSSLicensingService{},
}
uss := createService(t, setting.Cfg{})
setupSomeDataSourcePlugins(t, uss)
var getSystemStatsQuery *models.GetSystemStatsQuery
uss.Bus.AddHandler(func(query *models.GetSystemStatsQuery) error {
@ -166,22 +160,6 @@ func TestMetrics(t *testing.T) {
createConcurrentTokens(t, uss.SQLStore)
uss.AlertingUsageStats = &alertingUsageMock{}
var wg sync.WaitGroup
var responseBuffer *bytes.Buffer
var req *http.Request
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
req = r
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed to read response body, err=%v", err)
}
responseBuffer = bytes.NewBuffer(buf)
wg.Done()
}))
usageStatsURL = ts.URL
defer ts.Close()
uss.oauthProviders = map[string]bool{
"github": true,
"gitlab": true,
@ -195,125 +173,163 @@ func TestMetrics(t *testing.T) {
require.NoError(t, err)
t.Run("Given reporting not enabled and sending usage stats", func(t *testing.T) {
setting.ReportingEnabled = false
origSendUsageStats := sendUsageStats
t.Cleanup(func() {
sendUsageStats = origSendUsageStats
})
statsSent := false
sendUsageStats = func(*bytes.Buffer) {
statsSent = true
}
uss.Cfg.ReportingEnabled = false
err := uss.sendUsageStats(context.Background())
require.NoError(t, err)
t.Run("Should not gather stats or call http endpoint", func(t *testing.T) {
assert.Nil(t, getSystemStatsQuery)
assert.Nil(t, getDataSourceStatsQuery)
assert.Nil(t, getDataSourceAccessStatsQuery)
assert.Nil(t, req)
})
require.False(t, statsSent)
assert.Nil(t, getSystemStatsQuery)
assert.Nil(t, getDataSourceStatsQuery)
assert.Nil(t, getDataSourceAccessStatsQuery)
})
t.Run("Given reporting enabled and sending usage stats", func(t *testing.T) {
setting.ReportingEnabled = true
setting.BuildVersion = "5.0.0"
setting.AnonymousEnabled = true
setting.BasicAuthEnabled = true
setting.LDAPEnabled = true
setting.AuthProxyEnabled = true
setting.Packaging = "deb"
setting.ReportingDistributor = "hosted-grafana"
t.Run("Given reporting enabled, stats should be gathered and sent to HTTP endpoint", func(t *testing.T) {
origCfg := uss.Cfg
t.Cleanup(func() {
uss.Cfg = origCfg
})
uss.Cfg = &setting.Cfg{
ReportingEnabled: true,
BuildVersion: "5.0.0",
AnonymousEnabled: true,
BasicAuthEnabled: true,
LDAPEnabled: true,
AuthProxyEnabled: true,
Packaging: "deb",
ReportingDistributor: "hosted-grafana",
}
ch := make(chan httpResp)
ticker := time.NewTicker(2 * time.Second)
ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Logf("Fake HTTP handler received an error: %s", err.Error())
ch <- httpResp{
err: err,
}
return
}
require.NoError(t, err, "Failed to read response body, err=%v", err)
t.Logf("Fake HTTP handler received a response")
ch <- httpResp{
responseBuffer: bytes.NewBuffer(buf),
req: r,
}
}))
t.Cleanup(ts.Close)
t.Cleanup(func() {
close(ch)
})
usageStatsURL = ts.URL
wg.Add(1)
err := uss.sendUsageStats(context.Background())
require.NoError(t, err)
t.Run("Should gather stats and call http endpoint", func(t *testing.T) {
if waitTimeout(&wg, 2*time.Second) {
t.Fatalf("Timed out waiting for http request")
}
// Wait for fake HTTP server to receive a request
var resp httpResp
select {
case resp = <-ch:
require.NoError(t, resp.err, "Fake server experienced an error")
case <-ticker.C:
t.Fatalf("Timed out waiting for HTTP request")
}
assert.NotNil(t, getSystemStatsQuery)
assert.NotNil(t, getDataSourceStatsQuery)
assert.NotNil(t, getDataSourceAccessStatsQuery)
assert.NotNil(t, getAlertNotifierUsageStatsQuery)
assert.NotNil(t, req)
t.Logf("Received response from fake HTTP server: %+v\n", resp)
assert.Equal(t, http.MethodPost, req.Method)
assert.Equal(t, "application/json", req.Header.Get("Content-Type"))
assert.NotNil(t, getSystemStatsQuery)
assert.NotNil(t, getDataSourceStatsQuery)
assert.NotNil(t, getDataSourceAccessStatsQuery)
assert.NotNil(t, getAlertNotifierUsageStatsQuery)
assert.NotNil(t, resp.req)
assert.NotNil(t, responseBuffer)
assert.Equal(t, http.MethodPost, resp.req.Method)
assert.Equal(t, "application/json", resp.req.Header.Get("Content-Type"))
j, err := simplejson.NewFromReader(responseBuffer)
assert.Nil(t, err)
require.NotNil(t, resp.responseBuffer)
assert.Equal(t, "5_0_0", j.Get("version").MustString())
assert.Equal(t, runtime.GOOS, j.Get("os").MustString())
assert.Equal(t, runtime.GOARCH, j.Get("arch").MustString())
j, err := simplejson.NewFromReader(resp.responseBuffer)
require.NoError(t, err)
metrics := j.Get("metrics")
assert.Equal(t, getSystemStatsQuery.Result.Dashboards, metrics.Get("stats.dashboards.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Users, metrics.Get("stats.users.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Orgs, metrics.Get("stats.orgs.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Playlists, metrics.Get("stats.playlist.count").MustInt64())
assert.Equal(t, len(manager.Apps), metrics.Get("stats.plugins.apps.count").MustInt())
assert.Equal(t, len(manager.Panels), metrics.Get("stats.plugins.panels.count").MustInt())
assert.Equal(t, len(manager.DataSources), metrics.Get("stats.plugins.datasources.count").MustInt())
assert.Equal(t, getSystemStatsQuery.Result.Alerts, metrics.Get("stats.alerts.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.ActiveUsers, metrics.Get("stats.active_users.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Datasources, metrics.Get("stats.datasources.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Stars, metrics.Get("stats.stars.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Folders, metrics.Get("stats.folders.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.DashboardPermissions, metrics.Get("stats.dashboard_permissions.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.FolderPermissions, metrics.Get("stats.folder_permissions.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.ProvisionedDashboards, metrics.Get("stats.provisioned_dashboards.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Snapshots, metrics.Get("stats.snapshots.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Teams, metrics.Get("stats.teams.count").MustInt64())
assert.Equal(t, 15, metrics.Get("stats.total_auth_token.count").MustInt())
assert.Equal(t, 5, metrics.Get("stats.avg_auth_token_per_user.count").MustInt())
assert.Equal(t, 16, metrics.Get("stats.dashboard_versions.count").MustInt())
assert.Equal(t, 17, metrics.Get("stats.annotations.count").MustInt())
assert.Equal(t, "5_0_0", j.Get("version").MustString())
assert.Equal(t, runtime.GOOS, j.Get("os").MustString())
assert.Equal(t, runtime.GOARCH, j.Get("arch").MustString())
assert.Equal(t, 9, metrics.Get("stats.ds."+models.DS_ES+".count").MustInt())
assert.Equal(t, 10, metrics.Get("stats.ds."+models.DS_PROMETHEUS+".count").MustInt())
assert.Equal(t, 11+12, metrics.Get("stats.ds.other.count").MustInt())
metrics := j.Get("metrics")
assert.Equal(t, getSystemStatsQuery.Result.Dashboards, metrics.Get("stats.dashboards.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Users, metrics.Get("stats.users.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Orgs, metrics.Get("stats.orgs.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Playlists, metrics.Get("stats.playlist.count").MustInt64())
assert.Equal(t, len(manager.Apps), metrics.Get("stats.plugins.apps.count").MustInt())
assert.Equal(t, len(manager.Panels), metrics.Get("stats.plugins.panels.count").MustInt())
assert.Equal(t, len(manager.DataSources), metrics.Get("stats.plugins.datasources.count").MustInt())
assert.Equal(t, getSystemStatsQuery.Result.Alerts, metrics.Get("stats.alerts.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.ActiveUsers, metrics.Get("stats.active_users.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Datasources, metrics.Get("stats.datasources.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Stars, metrics.Get("stats.stars.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Folders, metrics.Get("stats.folders.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.DashboardPermissions, metrics.Get("stats.dashboard_permissions.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.FolderPermissions, metrics.Get("stats.folder_permissions.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.ProvisionedDashboards, metrics.Get("stats.provisioned_dashboards.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Snapshots, metrics.Get("stats.snapshots.count").MustInt64())
assert.Equal(t, getSystemStatsQuery.Result.Teams, metrics.Get("stats.teams.count").MustInt64())
assert.Equal(t, 15, metrics.Get("stats.total_auth_token.count").MustInt())
assert.Equal(t, 5, metrics.Get("stats.avg_auth_token_per_user.count").MustInt())
assert.Equal(t, 16, metrics.Get("stats.dashboard_versions.count").MustInt())
assert.Equal(t, 17, metrics.Get("stats.annotations.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.ds_access."+models.DS_ES+".direct.count").MustInt())
assert.Equal(t, 2, metrics.Get("stats.ds_access."+models.DS_ES+".proxy.count").MustInt())
assert.Equal(t, 3, metrics.Get("stats.ds_access."+models.DS_PROMETHEUS+".proxy.count").MustInt())
assert.Equal(t, 6+7, metrics.Get("stats.ds_access.other.direct.count").MustInt())
assert.Equal(t, 4+8, metrics.Get("stats.ds_access.other.proxy.count").MustInt())
assert.Equal(t, 9, metrics.Get("stats.ds."+models.DS_ES+".count").MustInt())
assert.Equal(t, 10, metrics.Get("stats.ds."+models.DS_PROMETHEUS+".count").MustInt())
assert.Equal(t, 11+12, metrics.Get("stats.ds.other.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.alerting.ds.prometheus.count").MustInt())
assert.Equal(t, 2, metrics.Get("stats.alerting.ds.graphite.count").MustInt())
assert.Equal(t, 5, metrics.Get("stats.alerting.ds.mysql.count").MustInt())
assert.Equal(t, 90, metrics.Get("stats.alerting.ds.other.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.ds_access."+models.DS_ES+".direct.count").MustInt())
assert.Equal(t, 2, metrics.Get("stats.ds_access."+models.DS_ES+".proxy.count").MustInt())
assert.Equal(t, 3, metrics.Get("stats.ds_access."+models.DS_PROMETHEUS+".proxy.count").MustInt())
assert.Equal(t, 6+7, metrics.Get("stats.ds_access.other.direct.count").MustInt())
assert.Equal(t, 4+8, metrics.Get("stats.ds_access.other.proxy.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.alert_notifiers.slack.count").MustInt())
assert.Equal(t, 2, metrics.Get("stats.alert_notifiers.webhook.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.alerting.ds.prometheus.count").MustInt())
assert.Equal(t, 2, metrics.Get("stats.alerting.ds.graphite.count").MustInt())
assert.Equal(t, 5, metrics.Get("stats.alerting.ds.mysql.count").MustInt())
assert.Equal(t, 90, metrics.Get("stats.alerting.ds.other.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.anonymous.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.basic_auth.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.ldap.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.auth_proxy.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_github.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_gitlab.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_google.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_azuread.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_generic_oauth.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_grafana_com.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.alert_notifiers.slack.count").MustInt())
assert.Equal(t, 2, metrics.Get("stats.alert_notifiers.webhook.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.packaging.deb.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.distributor.hosted-grafana.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.anonymous.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.basic_auth.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.ldap.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.auth_proxy.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_github.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_gitlab.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_google.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_azuread.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_generic_oauth.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_grafana_com.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_token_per_user_le_3").MustInt())
assert.Equal(t, 2, metrics.Get("stats.auth_token_per_user_le_6").MustInt())
assert.Equal(t, 3, metrics.Get("stats.auth_token_per_user_le_9").MustInt())
assert.Equal(t, 4, metrics.Get("stats.auth_token_per_user_le_12").MustInt())
assert.Equal(t, 5, metrics.Get("stats.auth_token_per_user_le_15").MustInt())
assert.Equal(t, 6, metrics.Get("stats.auth_token_per_user_le_inf").MustInt())
})
assert.Equal(t, 1, metrics.Get("stats.packaging.deb.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.distributor.hosted-grafana.count").MustInt())
assert.Equal(t, 1, metrics.Get("stats.auth_token_per_user_le_3").MustInt())
assert.Equal(t, 2, metrics.Get("stats.auth_token_per_user_le_6").MustInt())
assert.Equal(t, 3, metrics.Get("stats.auth_token_per_user_le_9").MustInt())
assert.Equal(t, 4, metrics.Get("stats.auth_token_per_user_le_12").MustInt())
assert.Equal(t, 5, metrics.Get("stats.auth_token_per_user_le_15").MustInt())
assert.Equal(t, 6, metrics.Get("stats.auth_token_per_user_le_inf").MustInt())
})
})
t.Run("When updating total stats", func(t *testing.T) {
uss := &UsageStatsService{
Bus: bus.New(),
Cfg: setting.NewCfg(),
}
uss := createService(t, setting.Cfg{})
uss.Cfg.MetricsEndpointEnabled = true
uss.Cfg.MetricsEndpointDisableTotalStats = false
getSystemStatsWasCalled := false
@ -323,56 +339,44 @@ func TestMetrics(t *testing.T) {
return nil
})
t.Run("When metrics is disabled and total stats is enabled", func(t *testing.T) {
t.Run("When metrics is disabled and total stats is enabled, stats should not be updated", func(t *testing.T) {
uss.Cfg.MetricsEndpointEnabled = false
uss.Cfg.MetricsEndpointDisableTotalStats = false
t.Run("Should not update stats", func(t *testing.T) {
uss.updateTotalStats()
uss.updateTotalStats()
assert.False(t, getSystemStatsWasCalled)
})
assert.False(t, getSystemStatsWasCalled)
})
t.Run("When metrics is enabled and total stats is disabled", func(t *testing.T) {
t.Run("When metrics is enabled and total stats is disabled, stats should not be updated", func(t *testing.T) {
uss.Cfg.MetricsEndpointEnabled = true
uss.Cfg.MetricsEndpointDisableTotalStats = true
t.Run("Should not update stats", func(t *testing.T) {
uss.updateTotalStats()
uss.updateTotalStats()
assert.False(t, getSystemStatsWasCalled)
})
assert.False(t, getSystemStatsWasCalled)
})
t.Run("When metrics is disabled and total stats is disabled", func(t *testing.T) {
t.Run("When metrics is disabled and total stats is disabled, stats should not be updated", func(t *testing.T) {
uss.Cfg.MetricsEndpointEnabled = false
uss.Cfg.MetricsEndpointDisableTotalStats = true
t.Run("Should not update stats", func(t *testing.T) {
uss.updateTotalStats()
uss.updateTotalStats()
assert.False(t, getSystemStatsWasCalled)
})
assert.False(t, getSystemStatsWasCalled)
})
t.Run("When metrics is enabled and total stats is enabled", func(t *testing.T) {
t.Run("When metrics is enabled and total stats is enabled, stats should be updated", func(t *testing.T) {
uss.Cfg.MetricsEndpointEnabled = true
uss.Cfg.MetricsEndpointDisableTotalStats = false
t.Run("Should update stats", func(t *testing.T) {
uss.updateTotalStats()
uss.updateTotalStats()
assert.True(t, getSystemStatsWasCalled)
})
assert.True(t, getSystemStatsWasCalled)
})
})
t.Run("When registering a metric", func(t *testing.T) {
uss := &UsageStatsService{
Bus: bus.New(),
Cfg: setting.NewCfg(),
externalMetrics: make(map[string]MetricFunc),
}
uss := createService(t, setting.Cfg{})
metricName := "stats.test_metric.count"
t.Run("Adds a new metric to the external metrics", func(t *testing.T) {
@ -380,37 +384,31 @@ func TestMetrics(t *testing.T) {
return 1, nil
})
metric, _ := uss.externalMetrics[metricName]()
metric, err := uss.externalMetrics[metricName]()
require.NoError(t, err)
assert.Equal(t, 1, metric)
})
t.Run("When metric already exists", func(t *testing.T) {
t.Run("When metric already exists, the metric should be overridden", func(t *testing.T) {
uss.RegisterMetric(metricName, func() (interface{}, error) {
return 1, nil
})
metric, _ := uss.externalMetrics[metricName]()
metric, err := uss.externalMetrics[metricName]()
require.NoError(t, err)
assert.Equal(t, 1, metric)
t.Run("Overrides the metric", func(t *testing.T) {
uss.RegisterMetric(metricName, func() (interface{}, error) {
return 2, nil
})
newMetric, _ := uss.externalMetrics[metricName]()
assert.Equal(t, 2, newMetric)
uss.RegisterMetric(metricName, func() (interface{}, error) {
return 2, nil
})
newMetric, err := uss.externalMetrics[metricName]()
require.NoError(t, err)
assert.Equal(t, 2, newMetric)
})
})
t.Run("When getting usage report", func(t *testing.T) {
uss := &UsageStatsService{
Bus: bus.New(),
Cfg: setting.NewCfg(),
SQLStore: sqlstore.InitTestDB(t),
License: &licensing.OSSLicensingService{},
AlertingUsageStats: &alertingUsageMock{},
externalMetrics: make(map[string]MetricFunc),
}
uss := createService(t, setting.Cfg{})
metricName := "stats.test_metric.count"
uss.Bus.AddHandler(func(query *models.GetSystemStatsQuery) error {
@ -453,7 +451,7 @@ func TestMetrics(t *testing.T) {
})
report, err := uss.GetUsageReport(context.Background())
assert.Nil(t, err, "Expected no error")
require.NoError(t, err, "Expected no error")
metric := report.Metrics[metricName]
assert.Equal(t, 1, metric)
@ -461,24 +459,18 @@ func TestMetrics(t *testing.T) {
})
t.Run("When registering external metrics", func(t *testing.T) {
uss := &UsageStatsService{
Bus: bus.New(),
Cfg: setting.NewCfg(),
externalMetrics: make(map[string]MetricFunc),
}
uss := createService(t, setting.Cfg{})
metrics := map[string]interface{}{"stats.test_metric.count": 1, "stats.test_metric_second.count": 2}
extMetricName := "stats.test_external_metric.count"
t.Run("Should add to metrics", func(t *testing.T) {
uss.RegisterMetric(extMetricName, func() (interface{}, error) {
return 1, nil
})
uss.registerExternalMetrics(metrics)
assert.Equal(t, 1, metrics[extMetricName])
uss.RegisterMetric(extMetricName, func() (interface{}, error) {
return 1, nil
})
uss.registerExternalMetrics(metrics)
assert.Equal(t, 1, metrics[extMetricName])
t.Run("When loading a metric results to an error", func(t *testing.T) {
uss.RegisterMetric(extMetricName, func() (interface{}, error) {
return 1, nil
@ -495,7 +487,7 @@ func TestMetrics(t *testing.T) {
extErrorMetric := metrics[extErrorMetricName]
extMetric := metrics[extMetricName]
assert.Nil(t, extErrorMetric, "Invalid metric should not be added")
require.Nil(t, extErrorMetric, "Invalid metric should not be added")
assert.Equal(t, 1, extMetric)
assert.Len(t, metrics, 3, "Expected only one available metric")
})
@ -503,20 +495,6 @@ func TestMetrics(t *testing.T) {
})
}
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
}
type alertingUsageMock struct{}
func (aum *alertingUsageMock) QueryUsageStats() (*alerting.UsageStats, error) {
@ -530,40 +508,58 @@ func (aum *alertingUsageMock) QueryUsageStats() (*alerting.UsageStats, error) {
}, nil
}
func setupSomeDataSourcePlugins(t *testing.T) {
func setupSomeDataSourcePlugins(t *testing.T, uss *UsageStatsService) {
t.Helper()
originalDataSources := manager.DataSources
t.Cleanup(func() { manager.DataSources = originalDataSources })
manager.DataSources = make(map[string]*plugins.DataSourcePlugin)
manager.DataSources[models.DS_ES] = &plugins.DataSourcePlugin{
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
manager.DataSources = map[string]*plugins.DataSourcePlugin{
models.DS_ES: {
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
},
},
},
}
manager.DataSources[models.DS_PROMETHEUS] = &plugins.DataSourcePlugin{
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
models.DS_PROMETHEUS: {
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
},
},
},
}
manager.DataSources[models.DS_GRAPHITE] = &plugins.DataSourcePlugin{
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
models.DS_GRAPHITE: {
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
},
},
},
}
manager.DataSources[models.DS_MYSQL] = &plugins.DataSourcePlugin{
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
models.DS_MYSQL: {
FrontendPluginBase: plugins.FrontendPluginBase{
PluginBase: plugins.PluginBase{
Signature: "internal",
},
},
},
}
}
type httpResp struct {
req *http.Request
responseBuffer *bytes.Buffer
err error
}
func createService(t *testing.T, cfg setting.Cfg) *UsageStatsService {
t.Helper()
return &UsageStatsService{
Bus: bus.New(),
Cfg: &cfg,
SQLStore: sqlstore.InitTestDB(t),
License: &licensing.OSSLicensingService{},
AlertingUsageStats: &alertingUsageMock{},
externalMetrics: make(map[string]MetricFunc),
}
}

View File

@ -141,10 +141,8 @@ var (
appliedEnvOverrides []string
// analytics
ReportingEnabled bool
ReportingDistributor string
GoogleAnalyticsId string
GoogleTagManagerId string
GoogleAnalyticsId string
GoogleTagManagerId string
// LDAP
LDAPEnabled bool
@ -337,7 +335,9 @@ type Cfg struct {
Env string
// Analytics
CheckForUpdates bool
CheckForUpdates bool
ReportingDistributor string
ReportingEnabled bool
// LDAP
LDAPEnabled bool
@ -831,10 +831,10 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error {
cfg.CheckForUpdates = analytics.Key("check_for_updates").MustBool(true)
GoogleAnalyticsId = analytics.Key("google_analytics_ua_id").String()
GoogleTagManagerId = analytics.Key("google_tag_manager_id").String()
ReportingEnabled = analytics.Key("reporting_enabled").MustBool(true)
ReportingDistributor = analytics.Key("reporting_distributor").MustString("grafana-labs")
if len(ReportingDistributor) >= 100 {
ReportingDistributor = ReportingDistributor[:100]
cfg.ReportingEnabled = analytics.Key("reporting_enabled").MustBool(true)
cfg.ReportingDistributor = analytics.Key("reporting_distributor").MustString("grafana-labs")
if len(cfg.ReportingDistributor) >= 100 {
cfg.ReportingDistributor = cfg.ReportingDistributor[:100]
}
if err := readAlertingSettings(iniFile); err != nil {