grafana/pkg/services/live/pushws/ws.go
2021-11-15 12:43:18 +03:00

114 lines
2.7 KiB
Go

package pushws
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/grafana/grafana/pkg/infra/log"
)
// logger is the package-wide logger, registered under the "live.push_ws" name.
var logger = log.New("live.push_ws")
// Config holds the tunables for Handler and for the websocket connections it
// sets up. Every field is optional; the zero value of each field selects the
// default described on that field.
type Config struct {
	// ReadBufferSize is the read buffer size passed to the raw websocket Upgrader.
	// If set to zero, a reasonable default value will be used.
	ReadBufferSize int
	// WriteBufferSize is the write buffer size passed to the raw websocket Upgrader.
	// If set to zero, a reasonable default value will be used.
	WriteBufferSize int
	// MessageSizeLimit sets the maximum size in bytes of a message accepted from
	// a client. When zero, setupWSConn uses DefaultWebsocketMessageSizeLimit.
	// The read limit is only applied to the connection when the effective
	// value is positive.
	MessageSizeLimit int
	// CheckOrigin provides custom origin check logic for incoming requests.
	// The zero value (nil) means a same-host check is used.
	CheckOrigin func(r *http.Request) bool
	// PingInterval sets the interval at which the server sends ping messages to
	// clients. When zero, setupWSConn uses DefaultWebsocketPingInterval.
	// The read deadline (pong wait) is derived from it as PingInterval*10/9 and
	// is only installed when the effective value is positive.
	PingInterval time.Duration
}
// sameHostOriginCheck builds an origin-check callback that accepts a request
// only when checkSameHost reports no error. Rejections are logged at warn
// level together with the offending Origin header value.
func sameHostOriginCheck() func(r *http.Request) bool {
	check := func(req *http.Request) bool {
		if err := checkSameHost(req); err != nil {
			logger.Warn("Origin check failure", "origin", req.Header.Get("origin"), "error", err)
			return false
		}
		return true
	}
	return check
}
// checkSameHost verifies that the request's Origin header, when present,
// names the same host as the request itself.
//
// It returns nil when the Origin header is absent or empty, or when the host
// part of the parsed Origin equals r.Host under case-insensitive comparison.
// It returns an error when the Origin header cannot be parsed as a URL (the
// parse error is wrapped) or when the hosts differ.
func checkSameHost(r *http.Request) error {
	rawOrigin := r.Header.Get("Origin")
	if len(rawOrigin) == 0 {
		// No Origin header: nothing to compare against, so allow.
		return nil
	}

	parsed, parseErr := url.Parse(rawOrigin)
	switch {
	case parseErr != nil:
		return fmt.Errorf("failed to parse Origin header %q: %w", rawOrigin, parseErr)
	case !strings.EqualFold(r.Host, parsed.Host):
		return fmt.Errorf("request Origin %q is not authorized for Host %q", rawOrigin, r.Host)
	default:
		return nil
	}
}
// Defaults substituted by setupWSConn when the matching Config field is zero.
const (
	// DefaultWebsocketMessageSizeLimit is the default maximum size in bytes
	// of a message read from a client (1 MiB).
	DefaultWebsocketMessageSizeLimit = 1 << 20

	// DefaultWebsocketPingInterval is the default ping interval.
	DefaultWebsocketPingInterval = 25 * time.Second
)
// setupWSConn applies the read limit and keep-alive settings from config to
// conn and, when keep-alive is enabled, starts a goroutine that pings the peer
// until ctx is done or a ping write fails.
//
// A zero config.PingInterval is replaced by DefaultWebsocketPingInterval and a
// zero config.MessageSizeLimit by DefaultWebsocketMessageSizeLimit. The read
// limit is applied only when the effective limit is positive. Keep-alive (read
// deadline, pong handler and ping goroutine) is set up only when the effective
// ping interval is positive.
func setupWSConn(ctx context.Context, conn *websocket.Conn, config Config) {
	pingInterval := config.PingInterval
	if pingInterval == 0 {
		pingInterval = DefaultWebsocketPingInterval
	}
	messageSizeLimit := config.MessageSizeLimit
	if messageSizeLimit == 0 {
		messageSizeLimit = DefaultWebsocketMessageSizeLimit
	}

	if messageSizeLimit > 0 {
		conn.SetReadLimit(int64(messageSizeLimit))
	}

	if pingInterval <= 0 {
		// Keep-alive disabled: no read deadline and no ping goroutine.
		// (time.NewTicker would also panic on a non-positive interval.)
		return
	}

	// Allow slightly more than one ping interval for a pong to arrive before
	// reads time out; every received pong pushes the deadline forward.
	pongWait := pingInterval * 10 / 9
	// Deadline errors are deliberately ignored here (best effort).
	_ = conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		_ = conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	go func() {
		// Tick at the effective ping interval so that pings are always sent
		// more often than the pongWait read deadline expires.
		ticker := time.NewTicker(pingInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				deadline := time.Now().Add(pingInterval / 2)
				if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
					// The connection is no longer writable; stop pinging.
					return
				}
			}
		}
	}()
}