mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
Websocket CORS Support (#5667)
* Second attept at patching api/websocket.go for CORS support. * Missing include * Fixed whitespace formatting so that gofmt passes. * Added tests for CORS filtering
This commit is contained in:
committed by
Christopher Speller
parent
34cb70d005
commit
120f5a6f8a
@@ -5,6 +5,7 @@ package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
l4g "github.com/alecthomas/log4go"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -19,11 +20,25 @@ func InitWebSocket() {
|
||||
app.HubStart()
|
||||
}
|
||||
|
||||
type OriginCheckerProc func(*http.Request) bool
|
||||
|
||||
func OriginChecker(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
return *utils.Cfg.ServiceSettings.AllowCorsFrom == "*" || strings.Contains(origin, *utils.Cfg.ServiceSettings.AllowCorsFrom)
|
||||
}
|
||||
|
||||
func connect(c *Context, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var originChecker OriginCheckerProc = nil
|
||||
|
||||
if len(*utils.Cfg.ServiceSettings.AllowCorsFrom) > 0 {
|
||||
originChecker = OriginChecker
|
||||
}
|
||||
|
||||
upgrader := websocket.Upgrader{
|
||||
ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
|
||||
WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
|
||||
CheckOrigin: nil,
|
||||
CheckOrigin: originChecker,
|
||||
}
|
||||
|
||||
ws, err := upgrader.Upgrade(w, r, nil)
|
||||
|
||||
@@ -316,6 +316,7 @@ func TestCreateDirectChannelWithSocket(t *testing.T) {
|
||||
|
||||
func TestWebsocketOriginSecurity(t *testing.T) {
|
||||
Setup().InitBasic()
|
||||
|
||||
url := "ws://localhost" + utils.Cfg.ServiceSettings.ListenAddress
|
||||
|
||||
// Should fail because origin doesn't match
|
||||
@@ -333,6 +334,35 @@ func TestWebsocketOriginSecurity(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should succeed now because open CORS
|
||||
*utils.Cfg.ServiceSettings.AllowCorsFrom = "*"
|
||||
_, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
|
||||
"Origin": []string{"http://www.evil.com"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should succeed now because matching CORS
|
||||
*utils.Cfg.ServiceSettings.AllowCorsFrom = "www.evil.com"
|
||||
_, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
|
||||
"Origin": []string{"http://www.evil.com"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should fail because non-matching CORS
|
||||
*utils.Cfg.ServiceSettings.AllowCorsFrom = "www.good.com"
|
||||
_, _, err = websocket.DefaultDialer.Dial(url+model.API_URL_SUFFIX_V3+"/users/websocket", http.Header{
|
||||
"Origin": []string{"http://www.evil.com"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Should have errored because Origin contain AllowCorsFrom")
|
||||
}
|
||||
|
||||
*utils.Cfg.ServiceSettings.AllowCorsFrom = ""
|
||||
}
|
||||
|
||||
func TestZZWebSocketTearDown(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user