2022-02-09 06:44:38 -06:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
2022-02-28 12:58:56 -06:00
|
|
|
|
|
|
|
"github.com/grafana/grafana/pkg/infra/log"
|
|
|
|
"github.com/grafana/grafana/pkg/util"
|
2022-02-09 06:44:38 -06:00
|
|
|
)
|
|
|
|
|
2022-02-28 12:58:56 -06:00
|
|
|
func CSRF(loginCookieName string, logger log.Logger) func(http.Handler) http.Handler {
|
2022-02-09 06:44:38 -06:00
|
|
|
// As per RFC 7231/4.2.2 these methods are idempotent:
|
|
|
|
// (GET is excluded because it may have side effects in some APIs)
|
|
|
|
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}
|
|
|
|
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
// If request has no login cookie - skip CSRF checks
|
|
|
|
if _, err := r.Cookie(loginCookieName); errors.Is(err, http.ErrNoCookie) {
|
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
// Skip CSRF checks for "safe" methods
|
|
|
|
for _, method := range safeMethods {
|
|
|
|
if r.Method == method {
|
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Otherwise - verify that Origin matches the server origin
|
2022-02-28 12:58:56 -06:00
|
|
|
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
|
|
|
|
if err != nil {
|
|
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2022-02-09 06:44:38 -06:00
|
|
|
origin, err := url.Parse(r.Header.Get("Origin"))
|
2022-02-28 12:58:56 -06:00
|
|
|
if err != nil {
|
|
|
|
logger.Error("error parsing Origin header", "err", err)
|
|
|
|
}
|
|
|
|
if err != nil || netAddr.Host == "" || (origin.String() != "" && origin.Hostname() != netAddr.Host) {
|
2022-02-09 06:44:38 -06:00
|
|
|
http.Error(w, "origin not allowed", http.StatusForbidden)
|
|
|
|
return
|
|
|
|
}
|
2022-02-28 12:58:56 -06:00
|
|
|
|
2022-02-09 06:44:38 -06:00
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|