mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
live: add ws endpoint to push into pipeline (#41534)
This commit is contained in:
parent
f3f441f4ec
commit
1700b2c2f3
@ -435,7 +435,7 @@ func (hs *HTTPServer) registerRoutes() {
|
||||
|
||||
if hs.Cfg.FeatureToggles["live-pipeline"] {
|
||||
// POST Live data to be processed according to channel rules.
|
||||
liveRoute.Post("/push/:streamId/:path", hs.LivePushGateway.HandlePath)
|
||||
liveRoute.Post("/pipeline/push/*", hs.LivePushGateway.HandlePipelinePush)
|
||||
liveRoute.Post("/pipeline-convert-test", routing.Wrap(hs.Live.HandlePipelineConvertTestHTTP), reqOrgAdmin)
|
||||
liveRoute.Get("/pipeline-entities", routing.Wrap(hs.Live.HandlePipelineEntitiesListHTTP), reqOrgAdmin)
|
||||
liveRoute.Get("/channel-rules", routing.Wrap(hs.Live.HandleChannelRulesListHTTP), reqOrgAdmin)
|
||||
|
@ -321,6 +321,12 @@ func ProvideService(plugCtxProvider *plugincontext.Provider, cfg *setting.Cfg, r
|
||||
CheckOrigin: checkOrigin,
|
||||
})
|
||||
|
||||
pushPipelineWSHandler := pushws.NewPipelinePushHandler(g.Pipeline, pushws.Config{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: checkOrigin,
|
||||
})
|
||||
|
||||
g.websocketHandler = func(ctx *models.ReqContext) {
|
||||
user := ctx.SignedInUser
|
||||
|
||||
@ -342,12 +348,21 @@ func ProvideService(plugCtxProvider *plugincontext.Provider, cfg *setting.Cfg, r
|
||||
pushWSHandler.ServeHTTP(ctx.Resp, r)
|
||||
}
|
||||
|
||||
g.pushPipelineWebsocketHandler = func(ctx *models.ReqContext) {
|
||||
user := ctx.SignedInUser
|
||||
newCtx := livecontext.SetContextSignedUser(ctx.Req.Context(), user)
|
||||
newCtx = livecontext.SetContextChannelID(newCtx, web.Params(ctx.Req)["*"])
|
||||
r := ctx.Req.WithContext(newCtx)
|
||||
pushPipelineWSHandler.ServeHTTP(ctx.Resp, r)
|
||||
}
|
||||
|
||||
g.RouteRegister.Group("/api/live", func(group routing.RouteRegister) {
|
||||
group.Get("/ws", g.websocketHandler)
|
||||
}, middleware.ReqSignedIn)
|
||||
|
||||
g.RouteRegister.Group("/api/live", func(group routing.RouteRegister) {
|
||||
group.Get("/push/:streamId", g.pushWebsocketHandler)
|
||||
group.Get("/pipeline/push/*", g.pushPipelineWebsocketHandler)
|
||||
}, middleware.ReqOrgAdmin)
|
||||
|
||||
g.registerUsageMetrics()
|
||||
@ -375,8 +390,9 @@ type GrafanaLive struct {
|
||||
surveyCaller *survey.Caller
|
||||
|
||||
// Websocket handlers
|
||||
websocketHandler interface{}
|
||||
pushWebsocketHandler interface{}
|
||||
websocketHandler interface{}
|
||||
pushWebsocketHandler interface{}
|
||||
pushPipelineWebsocketHandler interface{}
|
||||
|
||||
// Full channel handler
|
||||
channels map[string]models.ChannelHandler
|
||||
|
@ -37,3 +37,18 @@ func GetContextStreamID(ctx context.Context) (string, bool) {
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
type channelIDContextKey struct{}
|
||||
|
||||
func SetContextChannelID(ctx context.Context, channelID string) context.Context {
|
||||
ctx = context.WithValue(ctx, channelIDContextKey{}, channelID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func GetContextChannelID(ctx context.Context) (string, bool) {
|
||||
if val := ctx.Value(channelIDContextKey{}); val != nil {
|
||||
values, ok := val.(string)
|
||||
return values, ok
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
@ -64,21 +64,21 @@ func postTestData() {
|
||||
jsonData, _ := json.Marshal(d)
|
||||
log.Println(string(jsonData))
|
||||
|
||||
req, _ := http.NewRequest("POST", "http://localhost:3000/api/live/push/json/auto", bytes.NewReader(jsonData))
|
||||
req, _ := http.NewRequest("POST", "http://localhost:3000/api/live/pipeline/push/stream/json/auto", bytes.NewReader(jsonData))
|
||||
req.Header.Set("Authorization", "Bearer "+os.Getenv("GF_TOKEN"))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
req, _ = http.NewRequest("POST", "http://localhost:3000/api/live/push/json/tip", bytes.NewReader(jsonData))
|
||||
req, _ = http.NewRequest("POST", "http://localhost:3000/api/live/push/pipeline/push/stream/json/tip", bytes.NewReader(jsonData))
|
||||
req.Header.Set("Authorization", "Bearer "+os.Getenv("GF_TOKEN"))
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
req, _ = http.NewRequest("POST", "http://localhost:3000/api/live/push/json/exact", bytes.NewReader(jsonData))
|
||||
req, _ = http.NewRequest("POST", "http://localhost:3000/api/live/pipeline/push/stream/json/exact", bytes.NewReader(jsonData))
|
||||
req.Header.Set("Authorization", "Bearer "+os.Getenv("GF_TOKEN"))
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
|
@ -96,9 +96,8 @@ func (g *Gateway) Handle(ctx *models.ReqContext) {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) HandlePath(ctx *models.ReqContext) {
|
||||
streamID := web.Params(ctx.Req)[":streamId"]
|
||||
path := web.Params(ctx.Req)[":path"]
|
||||
func (g *Gateway) HandlePipelinePush(ctx *models.ReqContext) {
|
||||
channelID := web.Params(ctx.Req)["*"]
|
||||
|
||||
body, err := io.ReadAll(ctx.Req.Body)
|
||||
if err != nil {
|
||||
@ -108,13 +107,10 @@ func (g *Gateway) HandlePath(ctx *models.ReqContext) {
|
||||
}
|
||||
logger.Debug("Live channel push request",
|
||||
"protocol", "http",
|
||||
"streamId", streamID,
|
||||
"path", path,
|
||||
"channel", channelID,
|
||||
"bodyLength", len(body),
|
||||
)
|
||||
|
||||
channelID := "stream/" + streamID + "/" + path
|
||||
|
||||
ruleFound, err := g.GrafanaLive.Pipeline.ProcessInput(ctx.Req.Context(), ctx.OrgId, channelID, body)
|
||||
if err != nil {
|
||||
logger.Error("Pipeline input processing error", "error", err, "body", string(body))
|
||||
|
@ -1,200 +0,0 @@
|
||||
package pushws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/services/live/convert"
|
||||
"github.com/grafana/grafana/pkg/services/live/livecontext"
|
||||
"github.com/grafana/grafana/pkg/services/live/managedstream"
|
||||
"github.com/grafana/grafana/pkg/services/live/pushurl"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
liveDto "github.com/grafana/grafana-plugin-sdk-go/live"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = log.New("live.push_ws")
|
||||
)
|
||||
|
||||
// Handler handles WebSocket client connections that push data to Live.
|
||||
type Handler struct {
|
||||
managedStreamRunner *managedstream.Runner
|
||||
config Config
|
||||
upgrade *websocket.Upgrader
|
||||
converter *convert.Converter
|
||||
}
|
||||
|
||||
// Config represents config for Handler.
|
||||
type Config struct {
|
||||
// ReadBufferSize is a parameter that is used for raw websocket Upgrader.
|
||||
// If set to zero reasonable default value will be used.
|
||||
ReadBufferSize int
|
||||
|
||||
// WriteBufferSize is a parameter that is used for raw websocket Upgrader.
|
||||
// If set to zero reasonable default value will be used.
|
||||
WriteBufferSize int
|
||||
|
||||
// MessageSizeLimit sets the maximum size in bytes of allowed message from client.
|
||||
// By default DefaultWebsocketMessageSizeLimit will be used.
|
||||
MessageSizeLimit int
|
||||
|
||||
// CheckOrigin func to provide custom origin check logic,
|
||||
// zero value means same host check.
|
||||
CheckOrigin func(r *http.Request) bool
|
||||
|
||||
// PingInterval sets interval server will send ping messages to clients.
|
||||
// By default DefaultWebsocketPingInterval will be used.
|
||||
PingInterval time.Duration
|
||||
}
|
||||
|
||||
// NewHandler creates new Handler.
|
||||
func NewHandler(managedStreamRunner *managedstream.Runner, c Config) *Handler {
|
||||
if c.CheckOrigin == nil {
|
||||
c.CheckOrigin = sameHostOriginCheck()
|
||||
}
|
||||
upgrade := &websocket.Upgrader{
|
||||
ReadBufferSize: c.ReadBufferSize,
|
||||
WriteBufferSize: c.WriteBufferSize,
|
||||
CheckOrigin: c.CheckOrigin,
|
||||
}
|
||||
return &Handler{
|
||||
managedStreamRunner: managedStreamRunner,
|
||||
config: c,
|
||||
upgrade: upgrade,
|
||||
converter: convert.NewConverter(),
|
||||
}
|
||||
}
|
||||
|
||||
func sameHostOriginCheck() func(r *http.Request) bool {
|
||||
return func(r *http.Request) bool {
|
||||
err := checkSameHost(r)
|
||||
if err != nil {
|
||||
logger.Warn("Origin check failure", "origin", r.Header.Get("origin"), "error", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func checkSameHost(r *http.Request) error {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
|
||||
}
|
||||
if strings.EqualFold(r.Host, u.Host) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
|
||||
}
|
||||
|
||||
// Defaults.
|
||||
const (
|
||||
DefaultWebsocketPingInterval = 25 * time.Second
|
||||
DefaultWebsocketMessageSizeLimit = 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
func (s *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
streamID, ok := livecontext.GetContextStreamID(r.Context())
|
||||
if !ok || streamID == "" {
|
||||
logger.Warn("Push request without stream ID")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.upgrade.Upgrade(rw, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pingInterval := s.config.PingInterval
|
||||
if pingInterval == 0 {
|
||||
pingInterval = DefaultWebsocketPingInterval
|
||||
}
|
||||
messageSizeLimit := s.config.MessageSizeLimit
|
||||
if messageSizeLimit == 0 {
|
||||
messageSizeLimit = DefaultWebsocketMessageSizeLimit
|
||||
}
|
||||
|
||||
if messageSizeLimit > 0 {
|
||||
conn.SetReadLimit(int64(messageSizeLimit))
|
||||
}
|
||||
if pingInterval > 0 {
|
||||
pongWait := pingInterval * 10 / 9
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(25 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
deadline := time.Now().Add(pingInterval / 2)
|
||||
err := conn.WriteControl(websocket.PingMessage, nil, deadline)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
user, ok := livecontext.GetContextSignedUser(r.Context())
|
||||
if !ok {
|
||||
logger.Error("No user found in context")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
_, body, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
stream, err := s.managedStreamRunner.GetOrCreateStream(user.OrgId, liveDto.ScopeStream, streamID)
|
||||
if err != nil {
|
||||
logger.Error("Error getting stream", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO Grafana 8: decide which formats to use or keep all.
|
||||
urlValues := r.URL.Query()
|
||||
frameFormat := pushurl.FrameFormatFromValues(urlValues)
|
||||
|
||||
logger.Debug("Live Push request",
|
||||
"protocol", "http",
|
||||
"streamId", streamID,
|
||||
"bodyLength", len(body),
|
||||
"frameFormat", frameFormat,
|
||||
)
|
||||
|
||||
metricFrames, err := s.converter.Convert(body, frameFormat)
|
||||
if err != nil {
|
||||
logger.Error("Error converting metrics", "error", err, "frameFormat", frameFormat)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, mf := range metricFrames {
|
||||
err := stream.Push(mf.Key(), mf.Frame())
|
||||
if err != nil {
|
||||
logger.Error("Error pushing frame", "error", err, "data", string(body))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
82
pkg/services/live/pushws/push_pipeline.go
Normal file
82
pkg/services/live/pushws/push_pipeline.go
Normal file
@ -0,0 +1,82 @@
|
||||
package pushws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/live/convert"
|
||||
"github.com/grafana/grafana/pkg/services/live/livecontext"
|
||||
"github.com/grafana/grafana/pkg/services/live/pipeline"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// PipelinePushHandler handles WebSocket client connections that push data to Live Pipeline.
|
||||
type PipelinePushHandler struct {
|
||||
pipeline *pipeline.Pipeline
|
||||
config Config
|
||||
upgrade *websocket.Upgrader
|
||||
converter *convert.Converter
|
||||
}
|
||||
|
||||
// NewPathHandler creates new PipelinePushHandler.
|
||||
func NewPipelinePushHandler(pipeline *pipeline.Pipeline, c Config) *PipelinePushHandler {
|
||||
if c.CheckOrigin == nil {
|
||||
c.CheckOrigin = sameHostOriginCheck()
|
||||
}
|
||||
upgrade := &websocket.Upgrader{
|
||||
ReadBufferSize: c.ReadBufferSize,
|
||||
WriteBufferSize: c.WriteBufferSize,
|
||||
CheckOrigin: c.CheckOrigin,
|
||||
}
|
||||
return &PipelinePushHandler{
|
||||
pipeline: pipeline,
|
||||
config: c,
|
||||
upgrade: upgrade,
|
||||
converter: convert.NewConverter(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PipelinePushHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
channelID, ok := livecontext.GetContextChannelID(r.Context())
|
||||
if !ok || channelID == "" {
|
||||
logger.Warn("Push request without channel")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := livecontext.GetContextSignedUser(r.Context())
|
||||
if !ok {
|
||||
logger.Error("No user found in context")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.upgrade.Upgrade(rw, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
setupWSConn(r.Context(), conn, s.config)
|
||||
|
||||
for {
|
||||
_, body, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
logger.Debug("Live channel push request",
|
||||
"protocol", "http",
|
||||
"channel", channelID,
|
||||
"bodyLength", len(body),
|
||||
)
|
||||
|
||||
ruleFound, err := s.pipeline.ProcessInput(r.Context(), user.OrgId, channelID, body)
|
||||
if err != nil {
|
||||
logger.Error("Pipeline input processing error", "error", err, "body", string(body))
|
||||
return
|
||||
}
|
||||
if !ruleFound {
|
||||
logger.Error("No conversion rule for a channel", "error", err, "channel", channelID)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
99
pkg/services/live/pushws/push_stream.go
Normal file
99
pkg/services/live/pushws/push_stream.go
Normal file
@ -0,0 +1,99 @@
|
||||
package pushws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/live/convert"
|
||||
"github.com/grafana/grafana/pkg/services/live/livecontext"
|
||||
"github.com/grafana/grafana/pkg/services/live/managedstream"
|
||||
"github.com/grafana/grafana/pkg/services/live/pushurl"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
liveDto "github.com/grafana/grafana-plugin-sdk-go/live"
|
||||
)
|
||||
|
||||
// Handler handles WebSocket client connections that push data to Live.
|
||||
type Handler struct {
|
||||
managedStreamRunner *managedstream.Runner
|
||||
config Config
|
||||
upgrade *websocket.Upgrader
|
||||
converter *convert.Converter
|
||||
}
|
||||
|
||||
// NewHandler creates new Handler.
|
||||
func NewHandler(managedStreamRunner *managedstream.Runner, c Config) *Handler {
|
||||
if c.CheckOrigin == nil {
|
||||
c.CheckOrigin = sameHostOriginCheck()
|
||||
}
|
||||
upgrade := &websocket.Upgrader{
|
||||
ReadBufferSize: c.ReadBufferSize,
|
||||
WriteBufferSize: c.WriteBufferSize,
|
||||
CheckOrigin: c.CheckOrigin,
|
||||
}
|
||||
return &Handler{
|
||||
managedStreamRunner: managedStreamRunner,
|
||||
config: c,
|
||||
upgrade: upgrade,
|
||||
converter: convert.NewConverter(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
streamID, ok := livecontext.GetContextStreamID(r.Context())
|
||||
if !ok || streamID == "" {
|
||||
logger.Warn("Push request without stream ID")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := livecontext.GetContextSignedUser(r.Context())
|
||||
if !ok {
|
||||
logger.Error("No user found in context")
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.upgrade.Upgrade(rw, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
setupWSConn(r.Context(), conn, s.config)
|
||||
|
||||
for {
|
||||
_, body, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
stream, err := s.managedStreamRunner.GetOrCreateStream(user.OrgId, liveDto.ScopeStream, streamID)
|
||||
if err != nil {
|
||||
logger.Error("Error getting stream", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO Grafana 8: decide which formats to use or keep all.
|
||||
urlValues := r.URL.Query()
|
||||
frameFormat := pushurl.FrameFormatFromValues(urlValues)
|
||||
|
||||
logger.Debug("Live Push request",
|
||||
"protocol", "http",
|
||||
"streamId", streamID,
|
||||
"bodyLength", len(body),
|
||||
"frameFormat", frameFormat,
|
||||
)
|
||||
|
||||
metricFrames, err := s.converter.Convert(body, frameFormat)
|
||||
if err != nil {
|
||||
logger.Error("Error converting metrics", "error", err, "frameFormat", frameFormat)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, mf := range metricFrames {
|
||||
err := stream.Push(mf.Key(), mf.Frame())
|
||||
if err != nil {
|
||||
logger.Error("Error pushing frame", "error", err, "data", string(body))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
113
pkg/services/live/pushws/ws.go
Normal file
113
pkg/services/live/pushws/ws.go
Normal file
@ -0,0 +1,113 @@
|
||||
package pushws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = log.New("live.push_ws")
|
||||
)
|
||||
|
||||
// Config represents config for Handler.
|
||||
type Config struct {
|
||||
// ReadBufferSize is a parameter that is used for raw websocket Upgrader.
|
||||
// If set to zero reasonable default value will be used.
|
||||
ReadBufferSize int
|
||||
|
||||
// WriteBufferSize is a parameter that is used for raw websocket Upgrader.
|
||||
// If set to zero reasonable default value will be used.
|
||||
WriteBufferSize int
|
||||
|
||||
// MessageSizeLimit sets the maximum size in bytes of allowed message from client.
|
||||
// By default DefaultWebsocketMessageSizeLimit will be used.
|
||||
MessageSizeLimit int
|
||||
|
||||
// CheckOrigin func to provide custom origin check logic,
|
||||
// zero value means same host check.
|
||||
CheckOrigin func(r *http.Request) bool
|
||||
|
||||
// PingInterval sets interval server will send ping messages to clients.
|
||||
// By default DefaultWebsocketPingInterval will be used.
|
||||
PingInterval time.Duration
|
||||
}
|
||||
|
||||
func sameHostOriginCheck() func(r *http.Request) bool {
|
||||
return func(r *http.Request) bool {
|
||||
err := checkSameHost(r)
|
||||
if err != nil {
|
||||
logger.Warn("Origin check failure", "origin", r.Header.Get("origin"), "error", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func checkSameHost(r *http.Request) error {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
|
||||
}
|
||||
if strings.EqualFold(r.Host, u.Host) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
|
||||
}
|
||||
|
||||
// Defaults.
|
||||
const (
|
||||
DefaultWebsocketPingInterval = 25 * time.Second
|
||||
DefaultWebsocketMessageSizeLimit = 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
func setupWSConn(ctx context.Context, conn *websocket.Conn, config Config) {
|
||||
pingInterval := config.PingInterval
|
||||
if pingInterval == 0 {
|
||||
pingInterval = DefaultWebsocketPingInterval
|
||||
}
|
||||
messageSizeLimit := config.MessageSizeLimit
|
||||
if messageSizeLimit == 0 {
|
||||
messageSizeLimit = DefaultWebsocketMessageSizeLimit
|
||||
}
|
||||
|
||||
if messageSizeLimit > 0 {
|
||||
conn.SetReadLimit(int64(messageSizeLimit))
|
||||
}
|
||||
if pingInterval > 0 {
|
||||
pongWait := pingInterval * 10 / 9
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(25 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
deadline := time.Now().Add(pingInterval / 2)
|
||||
err := conn.WriteControl(websocket.PingMessage, nil, deadline)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
Loading…
Reference in New Issue
Block a user