API: Migrate CSRF to service and support additional options (#48120)

* API: Migrate CSRF to service and support additional options

* minor

* public Csrf service to use in tests

* WIP

* remove fmt

* comment

* WIP

* remove fmt prints

* todo add prefix slash

* remove fmt prints

* linting fix

* remove trimPrefix

Co-authored-by: Eric Leijonmarck <eric.leijonmarck@gmail.com>
Co-authored-by: IevaVasiljeva <ieva.vasiljeva@grafana.com>
This commit is contained in:
Emil Tullstedt
2022-06-02 15:52:30 +02:00
committed by GitHub
parent 84860ffc96
commit 3e81fa0716
5 changed files with 149 additions and 58 deletions

View File

@@ -13,6 +13,11 @@ import (
"strings"
"sync"
"github.com/grafana/grafana/pkg/middleware/csrf"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/grafana/grafana/pkg/api/avatar"
"github.com/grafana/grafana/pkg/api/routing"
httpstatic "github.com/grafana/grafana/pkg/api/static"
@@ -74,8 +79,6 @@ import (
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/grafana/grafana/pkg/web"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
type HTTPServer struct {
@@ -152,6 +155,7 @@ type HTTPServer struct {
PluginSettings *pluginSettings.Service
AvatarCacheServer *avatar.AvatarCacheServer
preferenceService pref.Service
Csrf csrf.Service
entityEventsService store.EntityEventsService
folderPermissionsService accesscontrol.FolderPermissionsService
dashboardPermissionsService accesscontrol.DashboardPermissionsService
@@ -192,7 +196,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
avatarCacheServer *avatar.AvatarCacheServer, preferenceService pref.Service, entityEventsService store.EntityEventsService,
teamsPermissionsService accesscontrol.TeamPermissionsService, folderPermissionsService accesscontrol.FolderPermissionsService,
dashboardPermissionsService accesscontrol.DashboardPermissionsService, dashboardVersionService dashver.Service,
starService star.Service, coremodelRegistry *coremodel.Registry,
starService star.Service, coremodelRegistry *coremodel.Registry, csrfService csrf.Service,
) (*HTTPServer, error) {
web.Env = cfg.Env
m := web.New()
@@ -266,6 +270,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
PluginSettings: pluginSettings,
AvatarCacheServer: avatarCacheServer,
preferenceService: preferenceService,
Csrf: csrfService,
entityEventsService: entityEventsService,
folderPermissionsService: folderPermissionsService,
dashboardPermissionsService: dashboardPermissionsService,
@@ -499,7 +504,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
}
m.Use(middleware.Recovery(hs.Cfg))
m.UseMiddleware(middleware.CSRF(hs.Cfg.LoginCookieName, hs.log))
m.UseMiddleware(hs.Csrf.Middleware(hs.log))
hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build")
hs.mapStatic(m, hs.Cfg.StaticRootPath, "", "public", "/public/views/swagger.html")

View File

@@ -1,50 +0,0 @@
package middleware
import (
"errors"
"net/http"
"net/url"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/util"
)
func CSRF(loginCookieName string, logger log.Logger) func(http.Handler) http.Handler {
// As per RFC 7231/4.2.2 these methods are idempotent:
// (GET is excluded because it may have side effects in some APIs)
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// If request has no login cookie - skip CSRF checks
if _, err := r.Cookie(loginCookieName); errors.Is(err, http.ErrNoCookie) {
next.ServeHTTP(w, r)
return
}
// Skip CSRF checks for "safe" methods
for _, method := range safeMethods {
if r.Method == method {
next.ServeHTTP(w, r)
return
}
}
// Otherwise - verify that Origin matches the server origin
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
origin, err := url.Parse(r.Header.Get("Origin"))
if err != nil {
logger.Error("error parsing Origin header", "err", err)
}
if err != nil || netAddr.Host == "" || (origin.String() != "" && origin.Hostname() != netAddr.Host) {
http.Error(w, "origin not allowed", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}

130
pkg/middleware/csrf/csrf.go Normal file
View File

@@ -0,0 +1,130 @@
package csrf
import (
"errors"
"net/http"
"net/url"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
type Service interface {
Middleware(logger log.Logger) func(http.Handler) http.Handler
TrustOrigin(origin string)
AddOriginHeader(headerName string)
AddSafeEndpoint(endpoint string)
}
type Implementation struct {
cfg *setting.Cfg
trustedOrigins map[string]struct{}
originHeaders map[string]struct{}
safeEndpoints map[string]struct{}
}
func ProvideCSRFFilter(cfg *setting.Cfg) Service {
i := &Implementation{
cfg: cfg,
trustedOrigins: map[string]struct{}{},
originHeaders: map[string]struct{}{
"Origin": {},
},
safeEndpoints: map[string]struct{}{},
}
additionalHeaders := cfg.SectionWithEnvOverrides("security").Key("csrf_additional_headers").Strings(" ")
trustedOrigins := cfg.SectionWithEnvOverrides("security").Key("csrf_trusted_origins").Strings(" ")
for _, header := range additionalHeaders {
i.originHeaders[header] = struct{}{}
}
for _, origin := range trustedOrigins {
i.trustedOrigins[origin] = struct{}{}
}
return i
}
func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.Handler {
// As per RFC 7231/4.2.2 these methods are idempotent:
// (GET is excluded because it may have side effects in some APIs)
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// If request has no login cookie - skip CSRF checks
if _, err := r.Cookie(i.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) {
next.ServeHTTP(w, r)
return
}
// Skip CSRF checks for "safe" methods
for _, method := range safeMethods {
if r.Method == method {
next.ServeHTTP(w, r)
return
}
}
// Skip CSRF checks for "safe" endpoints
for safeEndpoint := range i.safeEndpoints {
if r.URL.Path == safeEndpoint {
next.ServeHTTP(w, r)
return
}
}
// Otherwise - verify that Origin matches the server origin
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
origins := map[string]struct{}{}
for header := range i.originHeaders {
origin, err := url.Parse(r.Header.Get(header))
if err != nil {
logger.Error("error parsing Origin header", "header", header, "err", err)
}
if origin.String() != "" {
origins[origin.Hostname()] = struct{}{}
}
}
// No Origin header sent, skip CSRF check.
if len(origins) == 0 {
next.ServeHTTP(w, r)
return
}
trustedOrigin := false
for o := range i.trustedOrigins {
if _, ok := origins[o]; ok {
trustedOrigin = true
break
}
}
_, hostnameMatches := origins[netAddr.Host]
if netAddr.Host == "" || !trustedOrigin && !hostnameMatches {
http.Error(w, "origin not allowed", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}
func (i *Implementation) TrustOrigin(origin string) {
i.trustedOrigins[origin] = struct{}{}
}
func (i *Implementation) AddOriginHeader(headerName string) {
i.originHeaders[headerName] = struct{}{}
}
// AddSafeEndpoint is used for endpoints requests to skip CSRF check
func (i *Implementation) AddSafeEndpoint(endpoint string) {
i.safeEndpoints[endpoint] = struct{}{}
}

View File

@@ -1,12 +1,14 @@
package middleware
package csrf
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/setting"
)
func TestMiddlewareCSRF(t *testing.T) {
@@ -118,7 +120,10 @@ func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httpte
})
rr := httptest.NewRecorder()
handler := CSRF(cookieName, log.New())(testHandler)
cfg := setting.NewCfg()
cfg.LoginCookieName = cookieName
service := ProvideCSRFFilter(cfg)
handler := service.Middleware(log.New())(testHandler)
handler.ServeHTTP(rr, req)
return rr
}

View File

@@ -6,7 +6,6 @@ package server
import (
"github.com/google/wire"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/api"
"github.com/grafana/grafana/pkg/api/avatar"
"github.com/grafana/grafana/pkg/api/routing"
@@ -28,6 +27,7 @@ import (
"github.com/grafana/grafana/pkg/infra/usagestats/statscollector"
loginpkg "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware/csrf"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
@@ -254,6 +254,7 @@ var wireBasicSet = wire.NewSet(
cmreg.ProvideRegistry,
cuectx.ProvideCUEContext,
cuectx.ProvideThemaLibrary,
csrf.ProvideCSRFFilter,
ossaccesscontrol.ProvideTeamPermissions,
wire.Bind(new(accesscontrol.TeamPermissionsService), new(*ossaccesscontrol.TeamPermissionsService)),
ossaccesscontrol.ProvideFolderPermissions,