mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
API: Migrate CSRF to service and support additional options (#48120)
* API: Migrate CSRF to service and support additional options * minor * public Csrf service to use in tests * WIP * remove fmt * comment * WIP * remove fmt prints * todo add prefix slash * remove fmt prints * linting fix * remove trimPrefix Co-authored-by: Eric Leijonmarck <eric.leijonmarck@gmail.com> Co-authored-by: IevaVasiljeva <ieva.vasiljeva@grafana.com>
This commit is contained in:
@@ -13,6 +13,11 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/grafana/grafana/pkg/middleware/csrf"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/avatar"
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
httpstatic "github.com/grafana/grafana/pkg/api/static"
|
||||
@@ -74,8 +79,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
type HTTPServer struct {
|
||||
@@ -152,6 +155,7 @@ type HTTPServer struct {
|
||||
PluginSettings *pluginSettings.Service
|
||||
AvatarCacheServer *avatar.AvatarCacheServer
|
||||
preferenceService pref.Service
|
||||
Csrf csrf.Service
|
||||
entityEventsService store.EntityEventsService
|
||||
folderPermissionsService accesscontrol.FolderPermissionsService
|
||||
dashboardPermissionsService accesscontrol.DashboardPermissionsService
|
||||
@@ -192,7 +196,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
|
||||
avatarCacheServer *avatar.AvatarCacheServer, preferenceService pref.Service, entityEventsService store.EntityEventsService,
|
||||
teamsPermissionsService accesscontrol.TeamPermissionsService, folderPermissionsService accesscontrol.FolderPermissionsService,
|
||||
dashboardPermissionsService accesscontrol.DashboardPermissionsService, dashboardVersionService dashver.Service,
|
||||
starService star.Service, coremodelRegistry *coremodel.Registry,
|
||||
starService star.Service, coremodelRegistry *coremodel.Registry, csrfService csrf.Service,
|
||||
) (*HTTPServer, error) {
|
||||
web.Env = cfg.Env
|
||||
m := web.New()
|
||||
@@ -266,6 +270,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
|
||||
PluginSettings: pluginSettings,
|
||||
AvatarCacheServer: avatarCacheServer,
|
||||
preferenceService: preferenceService,
|
||||
Csrf: csrfService,
|
||||
entityEventsService: entityEventsService,
|
||||
folderPermissionsService: folderPermissionsService,
|
||||
dashboardPermissionsService: dashboardPermissionsService,
|
||||
@@ -499,7 +504,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
|
||||
}
|
||||
|
||||
m.Use(middleware.Recovery(hs.Cfg))
|
||||
m.UseMiddleware(middleware.CSRF(hs.Cfg.LoginCookieName, hs.log))
|
||||
m.UseMiddleware(hs.Csrf.Middleware(hs.log))
|
||||
|
||||
hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build")
|
||||
hs.mapStatic(m, hs.Cfg.StaticRootPath, "", "public", "/public/views/swagger.html")
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
func CSRF(loginCookieName string, logger log.Logger) func(http.Handler) http.Handler {
|
||||
// As per RFC 7231/4.2.2 these methods are idempotent:
|
||||
// (GET is excluded because it may have side effects in some APIs)
|
||||
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// If request has no login cookie - skip CSRF checks
|
||||
if _, err := r.Cookie(loginCookieName); errors.Is(err, http.ErrNoCookie) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// Skip CSRF checks for "safe" methods
|
||||
for _, method := range safeMethods {
|
||||
if r.Method == method {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Otherwise - verify that Origin matches the server origin
|
||||
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
origin, err := url.Parse(r.Header.Get("Origin"))
|
||||
if err != nil {
|
||||
logger.Error("error parsing Origin header", "err", err)
|
||||
}
|
||||
if err != nil || netAddr.Host == "" || (origin.String() != "" && origin.Hostname() != netAddr.Host) {
|
||||
http.Error(w, "origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
130
pkg/middleware/csrf/csrf.go
Normal file
130
pkg/middleware/csrf/csrf.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
Middleware(logger log.Logger) func(http.Handler) http.Handler
|
||||
TrustOrigin(origin string)
|
||||
AddOriginHeader(headerName string)
|
||||
AddSafeEndpoint(endpoint string)
|
||||
}
|
||||
|
||||
type Implementation struct {
|
||||
cfg *setting.Cfg
|
||||
|
||||
trustedOrigins map[string]struct{}
|
||||
originHeaders map[string]struct{}
|
||||
safeEndpoints map[string]struct{}
|
||||
}
|
||||
|
||||
func ProvideCSRFFilter(cfg *setting.Cfg) Service {
|
||||
i := &Implementation{
|
||||
cfg: cfg,
|
||||
trustedOrigins: map[string]struct{}{},
|
||||
originHeaders: map[string]struct{}{
|
||||
"Origin": {},
|
||||
},
|
||||
safeEndpoints: map[string]struct{}{},
|
||||
}
|
||||
|
||||
additionalHeaders := cfg.SectionWithEnvOverrides("security").Key("csrf_additional_headers").Strings(" ")
|
||||
trustedOrigins := cfg.SectionWithEnvOverrides("security").Key("csrf_trusted_origins").Strings(" ")
|
||||
|
||||
for _, header := range additionalHeaders {
|
||||
i.originHeaders[header] = struct{}{}
|
||||
}
|
||||
for _, origin := range trustedOrigins {
|
||||
i.trustedOrigins[origin] = struct{}{}
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *Implementation) Middleware(logger log.Logger) func(http.Handler) http.Handler {
|
||||
// As per RFC 7231/4.2.2 these methods are idempotent:
|
||||
// (GET is excluded because it may have side effects in some APIs)
|
||||
safeMethods := []string{"HEAD", "OPTIONS", "TRACE"}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// If request has no login cookie - skip CSRF checks
|
||||
if _, err := r.Cookie(i.cfg.LoginCookieName); errors.Is(err, http.ErrNoCookie) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// Skip CSRF checks for "safe" methods
|
||||
for _, method := range safeMethods {
|
||||
if r.Method == method {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Skip CSRF checks for "safe" endpoints
|
||||
for safeEndpoint := range i.safeEndpoints {
|
||||
if r.URL.Path == safeEndpoint {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Otherwise - verify that Origin matches the server origin
|
||||
netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
origins := map[string]struct{}{}
|
||||
for header := range i.originHeaders {
|
||||
origin, err := url.Parse(r.Header.Get(header))
|
||||
if err != nil {
|
||||
logger.Error("error parsing Origin header", "header", header, "err", err)
|
||||
}
|
||||
if origin.String() != "" {
|
||||
origins[origin.Hostname()] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// No Origin header sent, skip CSRF check.
|
||||
if len(origins) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
trustedOrigin := false
|
||||
for o := range i.trustedOrigins {
|
||||
if _, ok := origins[o]; ok {
|
||||
trustedOrigin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
_, hostnameMatches := origins[netAddr.Host]
|
||||
if netAddr.Host == "" || !trustedOrigin && !hostnameMatches {
|
||||
http.Error(w, "origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Implementation) TrustOrigin(origin string) {
|
||||
i.trustedOrigins[origin] = struct{}{}
|
||||
}
|
||||
|
||||
func (i *Implementation) AddOriginHeader(headerName string) {
|
||||
i.originHeaders[headerName] = struct{}{}
|
||||
}
|
||||
|
||||
// AddSafeEndpoint is used for endpoints requests to skip CSRF check
|
||||
func (i *Implementation) AddSafeEndpoint(endpoint string) {
|
||||
i.safeEndpoints[endpoint] = struct{}{}
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
package middleware
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
func TestMiddlewareCSRF(t *testing.T) {
|
||||
@@ -118,7 +120,10 @@ func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httpte
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler := CSRF(cookieName, log.New())(testHandler)
|
||||
cfg := setting.NewCfg()
|
||||
cfg.LoginCookieName = cookieName
|
||||
service := ProvideCSRFFilter(cfg)
|
||||
handler := service.Middleware(log.New())(testHandler)
|
||||
handler.ServeHTTP(rr, req)
|
||||
return rr
|
||||
}
|
||||
@@ -6,7 +6,6 @@ package server
|
||||
import (
|
||||
"github.com/google/wire"
|
||||
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api"
|
||||
"github.com/grafana/grafana/pkg/api/avatar"
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats/statscollector"
|
||||
loginpkg "github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/middleware/csrf"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
|
||||
@@ -254,6 +254,7 @@ var wireBasicSet = wire.NewSet(
|
||||
cmreg.ProvideRegistry,
|
||||
cuectx.ProvideCUEContext,
|
||||
cuectx.ProvideThemaLibrary,
|
||||
csrf.ProvideCSRFFilter,
|
||||
ossaccesscontrol.ProvideTeamPermissions,
|
||||
wire.Bind(new(accesscontrol.TeamPermissionsService), new(*ossaccesscontrol.TeamPermissionsService)),
|
||||
ossaccesscontrol.ProvideFolderPermissions,
|
||||
|
||||
Reference in New Issue
Block a user