origin checker refactor (#7889)

This commit is contained in:
Chris
2017-11-22 15:58:03 -06:00
committed by Christopher Speller
parent 77a1dc1f2f
commit 1ccf093803
4 changed files with 17 additions and 18 deletions

View File

@@ -18,12 +18,10 @@ func (api *API) InitWebSocket() {
}
func connect(c *Context, w http.ResponseWriter, r *http.Request) {
originChecker := utils.GetOriginChecker(r)
upgrader := websocket.Upgrader{
ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
CheckOrigin: originChecker,
CheckOrigin: c.App.OriginChecker(),
}
ws, err := upgrader.Upgrade(w, r, nil)

View File

@@ -19,12 +19,10 @@ func (api *API) InitWebSocket() {
}
func connectWebSocket(c *Context, w http.ResponseWriter, r *http.Request) {
originChecker := utils.GetOriginChecker(r)
upgrader := websocket.Upgrader{
ReadBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
WriteBufferSize: model.SOCKET_MAX_MESSAGE_SIZE_KB,
CheckOrigin: originChecker,
CheckOrigin: c.App.OriginChecker(),
}
ws, err := upgrader.Upgrade(w, r, nil)

View File

@@ -58,8 +58,8 @@ type CorsWrapper struct {
}
func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(*cw.config().ServiceSettings.AllowCorsFrom) > 0 {
if utils.OriginChecker(r) {
if allowed := *cw.config().ServiceSettings.AllowCorsFrom; allowed != "" {
if utils.CheckOrigin(r, allowed) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
if r.Method == "OPTIONS" {
@@ -252,6 +252,13 @@ func (a *App) StopServer() {
}
}
func (a *App) OriginChecker() func(*http.Request) bool {
if allowed := *a.Config().ServiceSettings.AllowCorsFrom; allowed != "" {
return utils.OriginChecker(allowed)
}
return nil
}
// This is required to re-use the underlying connection and not take up file descriptors
func consumeAndClose(r *http.Response) {
if r.Body != nil {

View File

@@ -11,14 +11,12 @@ import (
"github.com/mattermost/mattermost-server/model"
)
type OriginCheckerProc func(*http.Request) bool
func OriginChecker(r *http.Request) bool {
func CheckOrigin(r *http.Request, allowedOrigins string) bool {
origin := r.Header.Get("Origin")
if *Cfg.ServiceSettings.AllowCorsFrom == "*" {
if allowedOrigins == "*" {
return true
}
for _, allowed := range strings.Split(*Cfg.ServiceSettings.AllowCorsFrom, " ") {
for _, allowed := range strings.Split(allowedOrigins, " ") {
if allowed == origin {
return true
}
@@ -26,12 +24,10 @@ func OriginChecker(r *http.Request) bool {
return false
}
func GetOriginChecker(r *http.Request) OriginCheckerProc {
if len(*Cfg.ServiceSettings.AllowCorsFrom) > 0 {
return OriginChecker
func OriginChecker(allowedOrigins string) func(*http.Request) bool {
return func(r *http.Request) bool {
return CheckOrigin(r, allowedOrigins)
}
return nil
}
func RenderWebError(err *model.AppError, w http.ResponseWriter, r *http.Request) {