Websocket CORS Support (#5667)

* Second attept at patching api/websocket.go for CORS support.

* Missing include

* Fixed whitespace formatting so that gofmt passes.

* Added tests for CORS filtering
This commit is contained in:
Brad Howes
2017-03-23 14:10:52 +01:00
committed by Christopher Speller
parent 34cb70d005
commit 120f5a6f8a
2 changed files with 46 additions and 1 deletions

View File

@@ -5,6 +5,7 @@ package api
import (
"net/http"
"strings"
l4g "github.com/alecthomas/log4go"
"github.com/gorilla/websocket"
@@ -19,11 +20,25 @@ func InitWebSocket() {
app.HubStart()
}
type OriginCheckerProc func(*http.Request) bool
func OriginChecker(r *http.Request) bool {
origin := r.Header.Get("Origin")
return *utils.Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(origin, *utils.Cfg.ServiceSettings.AllowCorsFrom)
}
func connect(c *Context, w http.ResponseWriter, r *http.Request) {
var originChecker OriginCheckerProc = nil
if len(*utils.Cfg.ServiceSettings.AllowCorsFrom) > 0 {
originChecker = OriginChecker
}
upgrader := websocket.Upgrader{
ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
CheckOrigin: nil,
CheckOrigin: originChecker,
}
ws, err := upgrader.Upgrade(w, r, nil)

View File

@@ -316,6 +316,7 @@ func TestCreateDirectChannelWithSocket(t *testing.T) {
func TestWebsocketOriginSecurity(t *testing.T) {
Setup().InitBasic()
url := "ws://localhost" + utils.Cfg.ServiceSettings.ListenAddress
// Should fail because origin doesn't match
@@ -333,6 +334,35 @@ func TestWebsocketOriginSecurity(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// Should succeed now because open CORS
*utils.Cfg.ServiceSettings.AllowCorsFrom = "*"
_, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
"Origin": []string{"http://www.evil.com"},
})
if err != nil {
t.Fatal(err)
}
// Should succeed now because matching CORS
*utils.Cfg.ServiceSettings.AllowCorsFrom = "www.evil.com"
_, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
"Origin": []string{"http://www.evil.com"},
})
if err != nil {
t.Fatal(err)
}
// Should fail because non-matching CORS
*utils.Cfg.ServiceSettings.AllowCorsFrom = "www.good.com"
_, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
"Origin": []string{"http://www.evil.com"},
})
if err == nil {
t.Fatal("Should have errored because Origin contain AllowCorsFrom")
}
*utils.Cfg.ServiceSettings.AllowCorsFrom = ""
}
func TestZZWebSocketTearDown(t *testing.T) {