diff --git a/conf/defaults.ini b/conf/defaults.ini index 5ca7494f1ec..c361b82e670 100644 --- a/conf/defaults.ini +++ b/conf/defaults.ini @@ -901,6 +901,10 @@ plugin_catalog_url = https://grafana.com/grafana/plugins/ # tuning. 0 disables Live, -1 means unlimited connections. max_connections = 100 +# allowed_origins is a comma-separated list of origins that can establish connection with Grafana Live. +# If not set then origin will be matched over root_url. Supports globbing: see https://github.com/gobwas/glob. +allowed_origins = + # engine defines an HA (high availability) engine to use for Grafana Live. By default no engine used - in # this case Live features work only on a single Grafana server. # Available options: "redis". diff --git a/conf/sample.ini b/conf/sample.ini index 3930f1764d7..3e58183c5b7 100644 --- a/conf/sample.ini +++ b/conf/sample.ini @@ -887,6 +887,10 @@ # tuning. 0 disables Live, -1 means unlimited connections. ;max_connections = 100 +# allowed_origins is a comma-separated list of origins that can establish connection with Grafana Live. +# If not set then origin will be matched over root_url. Supports globbing: see https://github.com/gobwas/glob. +;allowed_origins = + # engine defines an HA (high availability) engine to use for Grafana Live. By default no engine used - in # this case Live features work only on a single Grafana server. Available options: "redis". # Setting ha_engine is an EXPERIMENTAL feature. diff --git a/pkg/services/live/live.go b/pkg/services/live/live.go index ded3d3c45f8..7b3ab5504f8 100644 --- a/pkg/services/live/live.go +++ b/pkg/services/live/live.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "github.com/gobwas/glob" + "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/api/routing" @@ -318,21 +320,21 @@ func (g *GrafanaLive) Init() error { return fmt.Errorf("error parsing AppURL %s: %w", g.Cfg.AppURL, err) } + originPatterns := g.Cfg.LiveAllowedOrigins + originGlobs, _ := setting.GetAllowedOriginGlobs(originPatterns) // error already checked on config load. + checkOrigin := getCheckOriginFunc(appURL, originPatterns, originGlobs) + // Use a pure websocket transport. wsHandler := centrifuge.NewWebsocketHandler(node, centrifuge.WebsocketConfig{ ReadBufferSize: 1024, WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return checkOrigin(r, appURL) - }, + CheckOrigin: checkOrigin, }) pushWSHandler := pushws.NewHandler(g.ManagedStreamRunner, pushws.Config{ ReadBufferSize: 1024, WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return checkOrigin(r, appURL) - }, + CheckOrigin: checkOrigin, }) g.websocketHandler = func(ctx *models.ReqContext) { @@ -371,21 +373,44 @@ func (g *GrafanaLive) Init() error { return nil } -func checkOrigin(r *http.Request, appURL *url.URL) bool { - origin := r.Header.Get("Origin") - if origin == "" { +func getCheckOriginFunc(appURL *url.URL, originPatterns []string, originGlobs []glob.Glob) func(r *http.Request) bool { + return func(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { + return true + } + if len(originPatterns) == 1 && originPatterns[0] == "*" { + // fast path for *. + return true + } + ok, err := checkAllowedOrigin(strings.ToLower(origin), appURL, originGlobs) + if err != nil { + logger.Warn("Error parsing request origin", "error", err, "origin", origin) + return false + } + if !ok { + logger.Warn("Request Origin is not authorized", "origin", origin, "appUrl", appURL.String(), "allowedOrigins", strings.Join(originPatterns, ",")) + return false + } return true } +} + +func checkAllowedOrigin(origin string, appURL *url.URL, originGlobs []glob.Glob) (bool, error) { originURL, err := url.Parse(origin) if err != nil { logger.Warn("Failed to parse request origin", "error", err, "origin", origin) - return false + return false, err } - if !strings.EqualFold(originURL.Scheme, appURL.Scheme) || !strings.EqualFold(originURL.Host, appURL.Host) { - logger.Warn("Request Origin is not authorized", "origin", origin, "appUrl", appURL.String()) - return false + if strings.EqualFold(originURL.Scheme, appURL.Scheme) && strings.EqualFold(originURL.Host, appURL.Host) { + return true, nil } - return true + for _, pattern := range originGlobs { + if pattern.Match(origin) { + return true, nil + } + } + return false, nil } func runConcurrentlyIfNeeded(ctx context.Context, semaphore chan struct{}, fn func()) error { diff --git a/pkg/services/live/live_test.go b/pkg/services/live/live_test.go index 06a501bc197..44db4556ddc 100644 --- a/pkg/services/live/live_test.go +++ b/pkg/services/live/live_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/require" ) @@ -55,10 +57,11 @@ func Test_runConcurrentlyIfNeeded_DeadlineExceeded(t *testing.T) { func TestCheckOrigin(t *testing.T) { testCases := []struct { - name string - origin string - appURL string - success bool + name string + origin string + appURL string + allowedOrigins []string + success bool }{ { name: "empty_origin", @@ -96,6 +99,27 @@ func TestCheckOrigin(t *testing.T) { appURL: "https://example.com", success: true, }, + { + name: "authorized_allowed_origins", + origin: "https://test.example.com", + appURL: "http://localhost:3000/", + allowedOrigins: []string{"https://test.example.com"}, + success: true, + }, + { + name: "authorized_allowed_origins_pattern", + origin: "https://test.example.com", + appURL: "http://localhost:3000/", + allowedOrigins: []string{"https://*.example.com"}, + success: true, + }, + { + name: "authorized_allowed_origins_all", + origin: "https://test.example.com", + appURL: "http://localhost:3000/", + allowedOrigins: []string{"*"}, + success: true, + }, } for _, tc := range testCases { @@ -104,9 +128,15 @@ func TestCheckOrigin(t *testing.T) { t.Parallel() appURL, err := url.Parse(tc.appURL) require.NoError(t, err) + + originGlobs, err := setting.GetAllowedOriginGlobs(tc.allowedOrigins) + require.NoError(t, err) + + checkOrigin := getCheckOriginFunc(appURL, tc.allowedOrigins, originGlobs) + r := httptest.NewRequest("GET", tc.appURL, nil) r.Header.Set("Origin", tc.origin) - require.Equal(t, tc.success, checkOrigin(r, appURL), + require.Equal(t, tc.success, checkOrigin(r), "origin %s, appURL: %s", tc.origin, tc.appURL, ) }) diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index 2232d88b970..b019eda8ef8 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -17,6 +17,8 @@ import ( "strings" "time" + "github.com/gobwas/glob" + "github.com/prometheus/common/model" ini "gopkg.in/ini.v1" @@ -391,6 +393,9 @@ type Cfg struct { LiveHAEngine string // LiveHAEngineAddress is a connection address for Live HA engine. LiveHAEngineAddress string + // LiveAllowedOrigins is a set of origins accepted by Live. If not provided + // then Live uses AppURL as the only allowed origin. + LiveAllowedOrigins []string // Grafana.com URL GrafanaComURL string @@ -1446,6 +1451,19 @@ func (cfg *Cfg) readDataSourcesSettings() { cfg.DataSourceLimit = datasources.Key("datasource_limit").MustInt(5000) } +func GetAllowedOriginGlobs(originPatterns []string) ([]glob.Glob, error) { + var originGlobs []glob.Glob + allowedOrigins := originPatterns + for _, originPattern := range allowedOrigins { + g, err := glob.Compile(originPattern) + if err != nil { + return nil, fmt.Errorf("error parsing origin pattern: %v", err) + } + originGlobs = append(originGlobs, g) + } + return originGlobs, nil +} + func (cfg *Cfg) readLiveSettings(iniFile *ini.File) error { section := iniFile.Section("live") cfg.LiveMaxConnections = section.Key("max_connections").MustInt(100) @@ -1459,5 +1477,20 @@ func (cfg *Cfg) readLiveSettings(iniFile *ini.File) error { return fmt.Errorf("unsupported live HA engine type: %s", cfg.LiveHAEngine) } cfg.LiveHAEngineAddress = section.Key("ha_engine_address").MustString("127.0.0.1:6379") + + var originPatterns []string + allowedOrigins := section.Key("allowed_origins").MustString("") + for _, originPattern := range strings.Split(allowedOrigins, ",") { + originPattern = strings.TrimSpace(originPattern) + if originPattern == "" { + continue + } + originPatterns = append(originPatterns, originPattern) + } + _, err := GetAllowedOriginGlobs(originPatterns) + if err != nil { + return err + } + cfg.LiveAllowedOrigins = originPatterns return nil }