mirror of
https://github.com/grafana/grafana.git
synced 2024-11-22 08:56:43 -06:00
Live: rely on app url for origin check (#35983)
This commit is contained in:
parent
5da8f3e258
commit
5bbf45592e
@ -84,6 +84,7 @@ type testState struct {
|
||||
func newTestLive(t *testing.T) *live.GrafanaLive {
|
||||
gLive := live.NewGrafanaLive()
|
||||
gLive.RouteRegister = routing.NewRouteRegister()
|
||||
gLive.Cfg = &setting.Cfg{AppURL: "http://localhost:3000/"}
|
||||
err := gLive.Init()
|
||||
require.NoError(t, err)
|
||||
return gLive
|
||||
|
@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -230,15 +232,26 @@ func (g *GrafanaLive) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
appURL, err := url.Parse(g.Cfg.AppURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing AppURL %s: %w", g.Cfg.AppURL, err)
|
||||
}
|
||||
|
||||
// Use a pure websocket transport.
|
||||
wsHandler := centrifuge.NewWebsocketHandler(node, centrifuge.WebsocketConfig{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return checkOrigin(r, appURL)
|
||||
},
|
||||
})
|
||||
|
||||
pushWSHandler := pushws.NewHandler(g.ManagedStreamRunner, pushws.Config{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return checkOrigin(r, appURL)
|
||||
},
|
||||
})
|
||||
|
||||
g.websocketHandler = func(ctx *models.ReqContext) {
|
||||
@ -277,6 +290,23 @@ func (g *GrafanaLive) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkOrigin(r *http.Request, appURL *url.URL) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to parse request origin", "error", err, "origin", origin)
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(originURL.Scheme, appURL.Scheme) || !strings.EqualFold(originURL.Host, appURL.Host) {
|
||||
logger.Warn("Request Origin is not authorized", "origin", origin, "appUrl", appURL.String())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func runConcurrentlyIfNeeded(ctx context.Context, semaphore chan struct{}, fn func()) error {
|
||||
if cap(semaphore) > 1 {
|
||||
select {
|
||||
|
@ -2,6 +2,8 @@ package live
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -50,3 +52,63 @@ func Test_runConcurrentlyIfNeeded_DeadlineExceeded(t *testing.T) {
|
||||
err := runConcurrentlyIfNeeded(ctx, semaphore, f)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestCheckOrigin(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
origin string
|
||||
appURL string
|
||||
success bool
|
||||
}{
|
||||
{
|
||||
name: "empty_origin",
|
||||
origin: "",
|
||||
appURL: "http://localhost:3000/",
|
||||
success: true,
|
||||
},
|
||||
{
|
||||
name: "valid_origin",
|
||||
origin: "http://localhost:3000",
|
||||
appURL: "http://localhost:3000/",
|
||||
success: true,
|
||||
},
|
||||
{
|
||||
name: "unauthorized_origin",
|
||||
origin: "http://localhost:8000",
|
||||
appURL: "http://localhost:3000/",
|
||||
success: false,
|
||||
},
|
||||
{
|
||||
name: "bad_origin",
|
||||
origin: ":::http://localhost:8000",
|
||||
appURL: "http://localhost:3000/",
|
||||
success: false,
|
||||
},
|
||||
{
|
||||
name: "different_scheme",
|
||||
origin: "http://example.com",
|
||||
appURL: "https://example.com",
|
||||
success: false,
|
||||
},
|
||||
{
|
||||
name: "authorized_case_insensitive",
|
||||
origin: "https://examplE.com",
|
||||
appURL: "https://example.com",
|
||||
success: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
appURL, err := url.Parse(tc.appURL)
|
||||
require.NoError(t, err)
|
||||
r := httptest.NewRequest("GET", tc.appURL, nil)
|
||||
r.Header.Set("Origin", tc.origin)
|
||||
require.Equal(t, tc.success, checkOrigin(r, appURL),
|
||||
"origin %s, appURL: %s", tc.origin, tc.appURL,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user