mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
pkg/web: closure-style middlewares (#51238)
* pkg/web: closure-style middlewares Switches the middleware execution model from web.Handlers in a slice to web.Middleware. Middlewares are temporarily kept in a slice to preserve ordering, but prior to execution they are applied, forming a giant call-stack, giving granular control over the execution flow. * pkg/middleware: adapt to web.Middleware * pkg/middleware/recovery: use c.Req over req c.Req gets updated by future handlers, while req stays static. The current recovery implementation needs this newer information * pkg/web: correct middleware ordering * pkg/webtest: adapt middleware * pkg/web/hack: set w and r onto web.Context By adopting std middlewares, it may happen they invoke next(w,r) without putting their modified w,r into the web.Context, leading old-style handlers to operate on outdated fields. pkg/web now takes care of this * pkg/middleware: selectively use future context * pkg/web: accept closure-style on Use() * webtest: Middleware testing adds a utility function to web/webtest to obtain a http.ResponseWriter, http.Request and http.Handler the same as a middleware that runs would receive * *: cleanup * pkg/web: don't wrap Middleware from Router * pkg/web: require chain to write response * *: remove temp files * webtest: don't require chain write * *: cleanup
This commit is contained in:
parent
3893c46976
commit
534ece064b
@ -526,7 +526,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
|
||||
m.UseMiddleware(middleware.Gziper())
|
||||
}
|
||||
|
||||
m.Use(middleware.Recovery(hs.Cfg))
|
||||
m.UseMiddleware(middleware.Recovery(hs.Cfg))
|
||||
m.UseMiddleware(hs.Csrf.Middleware())
|
||||
|
||||
hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build")
|
||||
|
@ -4,7 +4,6 @@ package response
|
||||
//NOTE: This file belongs into pkg/web, but due to cyclic imports that are hard to resolve at the current time, it temporarily lives here.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@ -24,23 +23,25 @@ type (
|
||||
|
||||
func wrap_handler(h web.Handler) http.HandlerFunc {
|
||||
switch handle := h.(type) {
|
||||
case http.HandlerFunc:
|
||||
return handle
|
||||
case handlerStd:
|
||||
return handle
|
||||
case handlerStdCtx:
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handle(w, r, web.FromContext(r.Context()))
|
||||
handle(w, r, webCtx(w, r))
|
||||
}
|
||||
case handlerStdReqCtx:
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handle(w, r, getReqCtx(r.Context()))
|
||||
handle(w, r, reqCtx(w, r))
|
||||
}
|
||||
case handlerReqCtx:
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handle(getReqCtx(r.Context()))
|
||||
handle(reqCtx(w, r))
|
||||
}
|
||||
case handlerReqCtxRes:
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := getReqCtx(r.Context())
|
||||
ctx := reqCtx(w, r)
|
||||
res := handle(ctx)
|
||||
if res != nil {
|
||||
res.WriteTo(ctx)
|
||||
@ -48,15 +49,27 @@ func wrap_handler(h web.Handler) http.HandlerFunc {
|
||||
}
|
||||
case handlerCtx:
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
handle(web.FromContext(r.Context()))
|
||||
handle(webCtx(w, r))
|
||||
}
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("unexpected handler type: %T", h))
|
||||
}
|
||||
|
||||
func getReqCtx(ctx context.Context) *models.ReqContext {
|
||||
reqCtx, ok := ctx.Value(ctxkey.Key{}).(*models.ReqContext)
|
||||
func webCtx(w http.ResponseWriter, r *http.Request) *web.Context {
|
||||
ctx := web.FromContext(r.Context())
|
||||
if ctx == nil {
|
||||
panic("no *web.Context found")
|
||||
}
|
||||
|
||||
ctx.Req = r
|
||||
ctx.Resp = web.Rw(w, r)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func reqCtx(w http.ResponseWriter, r *http.Request) *models.ReqContext {
|
||||
wCtx := webCtx(w, r)
|
||||
reqCtx, ok := wCtx.Req.Context().Value(ctxkey.Key{}).(*models.ReqContext)
|
||||
if !ok {
|
||||
panic("no *models.ReqContext found")
|
||||
}
|
||||
|
@ -27,50 +27,52 @@ import (
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
func Logger(cfg *setting.Cfg) web.Handler {
|
||||
return func(res http.ResponseWriter, req *http.Request, c *web.Context) {
|
||||
start := time.Now()
|
||||
func Logger(cfg *setting.Cfg) web.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
rw := res.(web.ResponseWriter)
|
||||
c.Next()
|
||||
rw := web.Rw(w, r)
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
timeTaken := time.Since(start) / time.Millisecond
|
||||
duration := time.Since(start).String()
|
||||
ctx := contexthandler.FromContext(c.Req.Context())
|
||||
if ctx != nil && ctx.PerfmonTimer != nil {
|
||||
ctx.PerfmonTimer.Observe(float64(timeTaken))
|
||||
}
|
||||
|
||||
status := rw.Status()
|
||||
if status == 200 || status == 304 {
|
||||
if !cfg.RouterLogging {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
logParams := []interface{}{
|
||||
"method", req.Method,
|
||||
"path", req.URL.Path,
|
||||
"status", status,
|
||||
"remote_addr", c.RemoteAddr(),
|
||||
"time_ms", int64(timeTaken),
|
||||
"duration", duration,
|
||||
"size", rw.Size(),
|
||||
"referer", SanitizeURL(ctx, req.Referer()),
|
||||
timeTaken := time.Since(start) / time.Millisecond
|
||||
duration := time.Since(start).String()
|
||||
ctx := contexthandler.FromContext(r.Context())
|
||||
if ctx != nil && ctx.PerfmonTimer != nil {
|
||||
ctx.PerfmonTimer.Observe(float64(timeTaken))
|
||||
}
|
||||
|
||||
traceID := tracing.TraceIDFromContext(ctx.Req.Context(), false)
|
||||
if traceID != "" {
|
||||
logParams = append(logParams, "traceID", traceID)
|
||||
status := rw.Status()
|
||||
if status == 200 || status == 304 {
|
||||
if !cfg.RouterLogging {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if status >= 500 {
|
||||
ctx.Logger.Error("Request Completed", logParams...)
|
||||
} else {
|
||||
ctx.Logger.Info("Request Completed", logParams...)
|
||||
if ctx != nil {
|
||||
logParams := []interface{}{
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", status,
|
||||
"remote_addr", ctx.RemoteAddr(),
|
||||
"time_ms", int64(timeTaken),
|
||||
"duration", duration,
|
||||
"size", rw.Size(),
|
||||
"referer", SanitizeURL(ctx, r.Referer()),
|
||||
}
|
||||
|
||||
traceID := tracing.TraceIDFromContext(ctx.Req.Context(), false)
|
||||
if traceID != "" {
|
||||
logParams = append(logParams, "traceID", traceID)
|
||||
}
|
||||
|
||||
if status >= 500 {
|
||||
ctx.Logger.Error("Request Completed", logParams...)
|
||||
} else {
|
||||
ctx.Logger.Info("Request Completed", logParams...)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -647,6 +647,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
sc.context = c
|
||||
if sc.handlerFunc != nil {
|
||||
sc.handlerFunc(sc.context)
|
||||
if !c.Resp.Written() {
|
||||
c.Resp.WriteHeader(http.StatusOK)
|
||||
}
|
||||
} else {
|
||||
t.Log("Returning JSON OK")
|
||||
resp := make(map[string]interface{})
|
||||
|
@ -102,69 +102,73 @@ func function(pc uintptr) []byte {
|
||||
|
||||
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
|
||||
// While Martini is in development mode, Recovery will also output the panic as HTML.
|
||||
func Recovery(cfg *setting.Cfg) web.Handler {
|
||||
return func(c *web.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var panicLogger log.Logger
|
||||
panicLogger = log.New("recovery")
|
||||
// try to get request logger
|
||||
ctx := contexthandler.FromContext(c.Req.Context())
|
||||
if ctx != nil {
|
||||
panicLogger = ctx.Logger
|
||||
}
|
||||
func Recovery(cfg *setting.Cfg) web.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
c := web.FromContext(req.Context())
|
||||
|
||||
if err, ok := r.(error); ok {
|
||||
// http.ErrAbortHandler is suppressed by default in the http package
|
||||
// and used as a signal for aborting requests. Suppresses stacktrace
|
||||
// since it doesn't add any important information.
|
||||
if errors.Is(err, http.ErrAbortHandler) {
|
||||
panicLogger.Error("Request error", "error", err)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var panicLogger log.Logger
|
||||
panicLogger = log.New("recovery")
|
||||
// try to get request logger
|
||||
ctx := contexthandler.FromContext(c.Req.Context())
|
||||
if ctx != nil {
|
||||
panicLogger = ctx.Logger
|
||||
}
|
||||
|
||||
if err, ok := r.(error); ok {
|
||||
// http.ErrAbortHandler is suppressed by default in the http package
|
||||
// and used as a signal for aborting requests. Suppresses stacktrace
|
||||
// since it doesn't add any important information.
|
||||
if errors.Is(err, http.ErrAbortHandler) {
|
||||
panicLogger.Error("Request error", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stack := stack(3)
|
||||
panicLogger.Error("Request error", "error", r, "stack", string(stack))
|
||||
|
||||
// if response has already been written, skip.
|
||||
if c.Resp.Written() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stack := stack(3)
|
||||
panicLogger.Error("Request error", "error", r, "stack", string(stack))
|
||||
data := struct {
|
||||
Title string
|
||||
AppTitle string
|
||||
AppSubUrl string
|
||||
Theme string
|
||||
ErrorMsg string
|
||||
}{"Server Error", "Grafana", cfg.AppSubURL, cfg.DefaultTheme, ""}
|
||||
|
||||
// if response has already been written, skip.
|
||||
if c.Resp.Written() {
|
||||
return
|
||||
}
|
||||
if setting.Env == setting.Dev {
|
||||
if err, ok := r.(error); ok {
|
||||
data.Title = err.Error()
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Title string
|
||||
AppTitle string
|
||||
AppSubUrl string
|
||||
Theme string
|
||||
ErrorMsg string
|
||||
}{"Server Error", "Grafana", cfg.AppSubURL, cfg.DefaultTheme, ""}
|
||||
|
||||
if setting.Env == setting.Dev {
|
||||
if err, ok := r.(error); ok {
|
||||
data.Title = err.Error()
|
||||
data.ErrorMsg = string(stack)
|
||||
}
|
||||
|
||||
data.ErrorMsg = string(stack)
|
||||
}
|
||||
if ctx != nil && ctx.IsApiRequest() {
|
||||
resp := make(map[string]interface{})
|
||||
resp["message"] = "Internal Server Error - Check the Grafana server logs for the detailed error message."
|
||||
|
||||
if ctx != nil && ctx.IsApiRequest() {
|
||||
resp := make(map[string]interface{})
|
||||
resp["message"] = "Internal Server Error - Check the Grafana server logs for the detailed error message."
|
||||
if data.ErrorMsg != "" {
|
||||
resp["error"] = fmt.Sprintf("%v - %v", data.Title, data.ErrorMsg)
|
||||
} else {
|
||||
resp["error"] = data.Title
|
||||
}
|
||||
|
||||
if data.ErrorMsg != "" {
|
||||
resp["error"] = fmt.Sprintf("%v - %v", data.Title, data.ErrorMsg)
|
||||
ctx.JSON(500, resp)
|
||||
} else {
|
||||
resp["error"] = data.Title
|
||||
ctx.HTML(500, cfg.ErrTemplateName, data)
|
||||
}
|
||||
|
||||
c.JSON(500, resp)
|
||||
} else {
|
||||
c.HTML(500, cfg.ErrTemplateName, data)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
next.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
||||
require.NoError(t, err)
|
||||
|
||||
sc.m = web.New()
|
||||
sc.m.Use(Recovery(cfg))
|
||||
sc.m.UseMiddleware(Recovery(cfg))
|
||||
|
||||
sc.m.Use(AddDefaultResponseHeaders(cfg))
|
||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||
|
@ -46,57 +46,60 @@ func init() {
|
||||
}
|
||||
|
||||
// RequestMetrics is a middleware handler that instruments the request.
|
||||
func RequestMetrics(features featuremgmt.FeatureToggles) web.Handler {
|
||||
func RequestMetrics(features featuremgmt.FeatureToggles) web.Middleware {
|
||||
log := log.New("middleware.request-metrics")
|
||||
|
||||
return func(res http.ResponseWriter, req *http.Request, c *web.Context) {
|
||||
rw := res.(web.ResponseWriter)
|
||||
now := time.Now()
|
||||
httpRequestsInFlight.Inc()
|
||||
defer httpRequestsInFlight.Dec()
|
||||
c.Next()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rw := web.Rw(w, r)
|
||||
now := time.Now()
|
||||
httpRequestsInFlight.Inc()
|
||||
defer httpRequestsInFlight.Dec()
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
status := rw.Status()
|
||||
code := sanitizeCode(status)
|
||||
status := rw.Status()
|
||||
code := sanitizeCode(status)
|
||||
|
||||
handler := "unknown"
|
||||
if routeOperation, exists := routeOperationName(c.Req); exists {
|
||||
handler = routeOperation
|
||||
} else {
|
||||
// if grafana does not recognize the handler and returns 404 we should register it as `notfound`
|
||||
if status == http.StatusNotFound {
|
||||
handler = "notfound"
|
||||
handler := "unknown"
|
||||
// TODO: do not depend on web.Context from the future
|
||||
if routeOperation, exists := routeOperationName(web.FromContext(r.Context()).Req); exists {
|
||||
handler = routeOperation
|
||||
} else {
|
||||
// log requests where we could not identify handler so we can register them.
|
||||
if features.IsEnabled(featuremgmt.FlagLogRequestsInstrumentedAsUnknown) {
|
||||
log.Warn("request instrumented as unknown", "path", c.Req.URL.Path, "status_code", status)
|
||||
// if grafana does not recognize the handler and returns 404 we should register it as `notfound`
|
||||
if status == http.StatusNotFound {
|
||||
handler = "notfound"
|
||||
} else {
|
||||
// log requests where we could not identify handler so we can register them.
|
||||
if features.IsEnabled(featuremgmt.FlagLogRequestsInstrumentedAsUnknown) {
|
||||
log.Warn("request instrumented as unknown", "path", r.URL.Path, "status_code", status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// avoiding the sanitize functions for in the new instrumentation
|
||||
// since they dont make much sense. We should remove them later.
|
||||
histogram := httpRequestDurationHistogram.
|
||||
WithLabelValues(handler, code, req.Method)
|
||||
if traceID := tracing.TraceIDFromContext(c.Req.Context(), true); traceID != "" {
|
||||
// Need to type-convert the Observer to an
|
||||
// ExemplarObserver. This will always work for a
|
||||
// HistogramVec.
|
||||
histogram.(prometheus.ExemplarObserver).ObserveWithExemplar(
|
||||
time.Since(now).Seconds(), prometheus.Labels{"traceID": traceID},
|
||||
)
|
||||
return
|
||||
}
|
||||
histogram.Observe(time.Since(now).Seconds())
|
||||
// avoiding the sanitize functions for in the new instrumentation
|
||||
// since they dont make much sense. We should remove them later.
|
||||
histogram := httpRequestDurationHistogram.
|
||||
WithLabelValues(handler, code, r.Method)
|
||||
if traceID := tracing.TraceIDFromContext(r.Context(), true); traceID != "" {
|
||||
// Need to type-convert the Observer to an
|
||||
// ExemplarObserver. This will always work for a
|
||||
// HistogramVec.
|
||||
histogram.(prometheus.ExemplarObserver).ObserveWithExemplar(
|
||||
time.Since(now).Seconds(), prometheus.Labels{"traceID": traceID},
|
||||
)
|
||||
return
|
||||
}
|
||||
histogram.Observe(time.Since(now).Seconds())
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(req.RequestURI, "/api/datasources/proxy"):
|
||||
countProxyRequests(status)
|
||||
case strings.HasPrefix(req.RequestURI, "/api/"):
|
||||
countApiRequests(status)
|
||||
default:
|
||||
countPageRequests(status)
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(r.RequestURI, "/api/datasources/proxy"):
|
||||
countProxyRequests(status)
|
||||
case strings.HasPrefix(r.RequestURI, "/api/"):
|
||||
countApiRequests(status)
|
||||
default:
|
||||
countPageRequests(status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,37 +61,38 @@ func routeOperationName(req *http.Request) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
func RequestTracing(tracer tracing.Tracer) web.Handler {
|
||||
return func(res http.ResponseWriter, req *http.Request, c *web.Context) {
|
||||
if strings.HasPrefix(c.Req.URL.Path, "/public/") ||
|
||||
c.Req.URL.Path == "/robots.txt" ||
|
||||
c.Req.URL.Path == "/favicon.ico" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
func RequestTracing(tracer tracing.Tracer) web.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if strings.HasPrefix(req.URL.Path, "/public/") || req.URL.Path == "/robots.txt" || req.URL.Path == "/favicon.ico" {
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
rw := res.(web.ResponseWriter)
|
||||
rw := web.Rw(w, req)
|
||||
|
||||
wireContext := otel.GetTextMapPropagator().Extract(req.Context(), propagation.HeaderCarrier(req.Header))
|
||||
ctx, span := tracer.Start(req.Context(), fmt.Sprintf("HTTP %s %s", req.Method, req.URL.Path), trace.WithLinks(trace.LinkFromContext(wireContext)))
|
||||
wireContext := otel.GetTextMapPropagator().Extract(req.Context(), propagation.HeaderCarrier(req.Header))
|
||||
ctx, span := tracer.Start(req.Context(), fmt.Sprintf("HTTP %s %s", req.Method, req.URL.Path), trace.WithLinks(trace.LinkFromContext(wireContext)))
|
||||
|
||||
c.Req = req.WithContext(ctx)
|
||||
c.Next()
|
||||
req = req.WithContext(ctx)
|
||||
next.ServeHTTP(w, req)
|
||||
|
||||
// Only call span.Finish when a route operation name have been set,
|
||||
// meaning that not set the span would not be reported.
|
||||
if routeOperation, exists := routeOperationName(c.Req); exists {
|
||||
defer span.End()
|
||||
span.SetName(fmt.Sprintf("HTTP %s %s", req.Method, routeOperation))
|
||||
}
|
||||
// Only call span.Finish when a route operation name have been set,
|
||||
// meaning that not set the span would not be reported.
|
||||
// TODO: do not depend on web.Context from the future
|
||||
if routeOperation, exists := routeOperationName(web.FromContext(req.Context()).Req); exists {
|
||||
defer span.End()
|
||||
span.SetName(fmt.Sprintf("HTTP %s %s", req.Method, routeOperation))
|
||||
}
|
||||
|
||||
status := rw.Status()
|
||||
status := rw.Status()
|
||||
|
||||
span.SetAttributes("http.status_code", status, attribute.Int("http.status_code", status))
|
||||
span.SetAttributes("http.url", req.RequestURI, attribute.String("http.url", req.RequestURI))
|
||||
span.SetAttributes("http.method", req.Method, attribute.String("http.method", req.Method))
|
||||
if status >= 400 {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf("error with HTTP status code %s", strconv.Itoa(status)))
|
||||
}
|
||||
span.SetAttributes("http.status_code", status, attribute.Int("http.status_code", status))
|
||||
span.SetAttributes("http.url", req.RequestURI, attribute.String("http.url", req.RequestURI))
|
||||
span.SetAttributes("http.method", req.Method, attribute.String("http.method", req.Method))
|
||||
if status >= 400 {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf("error with HTTP status code %s", strconv.Itoa(status)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -67,6 +67,7 @@ func TestMiddleware(t *testing.T) {
|
||||
endpointCalled := false
|
||||
server.Get("/", func(c *models.ReqContext) {
|
||||
endpointCalled = true
|
||||
c.Resp.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
|
@ -50,10 +50,10 @@ func getRequest(t *testing.T, url string, expStatusCode int) *http.Response {
|
||||
t.Helper()
|
||||
// nolint:gosec
|
||||
resp, err := http.Get(url)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, resp.Body.Close())
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if expStatusCode != resp.StatusCode {
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
@ -29,8 +29,7 @@ import (
|
||||
// Context represents the runtime context of current request of Macaron instance.
|
||||
// It is the integration of most frequently used middlewares and helper methods.
|
||||
type Context struct {
|
||||
handlers []http.Handler
|
||||
index int
|
||||
mws []Middleware
|
||||
|
||||
*Router
|
||||
Req *http.Request
|
||||
@ -39,30 +38,20 @@ type Context struct {
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (ctx *Context) handler() http.Handler {
|
||||
if ctx.index < len(ctx.handlers) {
|
||||
return ctx.handlers[ctx.index]
|
||||
}
|
||||
if ctx.index == len(ctx.handlers) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
panic("invalid index for context handler")
|
||||
}
|
||||
|
||||
// Next runs the next handler in the context chain
|
||||
func (ctx *Context) Next() {
|
||||
ctx.index++
|
||||
ctx.run()
|
||||
}
|
||||
|
||||
func (ctx *Context) run() {
|
||||
for ctx.index <= len(ctx.handlers) {
|
||||
ctx.handler().ServeHTTP(ctx.Resp, ctx.Req)
|
||||
h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
for i := len(ctx.mws) - 1; i >= 0; i-- {
|
||||
h = ctx.mws[i](h)
|
||||
}
|
||||
|
||||
ctx.index++
|
||||
if ctx.Resp.Written() {
|
||||
return
|
||||
}
|
||||
rw := ctx.Resp
|
||||
h.ServeHTTP(ctx.Resp, ctx.Req)
|
||||
|
||||
// Prevent the handler chain from not writing anything.
|
||||
// This indicates nearly always that a middleware is misbehaving and not calling its next.ServeHTTP().
|
||||
// In rare cases where a blank http.StatusOK without any body is wished, explicitly state that using w.WriteStatus(http.StatusOK)
|
||||
if !rw.Written() {
|
||||
panic("chain did not write HTTP response")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,28 +53,16 @@ type Handler interface{}
|
||||
//go:linkname hack_wrap github.com/grafana/grafana/pkg/api/response.wrap_handler
|
||||
func hack_wrap(Handler) http.HandlerFunc
|
||||
|
||||
// validateAndWrapHandler makes sure a handler is a callable function, it panics if not.
|
||||
// When the handler is also potential to be any built-in inject.FastInvoker,
|
||||
// it wraps the handler automatically to have some performance gain.
|
||||
func validateAndWrapHandler(h Handler) http.Handler {
|
||||
// wrapHandler turns any supported handler type into a http.Handler by wrapping it accordingly
|
||||
func wrapHandler(h Handler) http.Handler {
|
||||
return hack_wrap(h)
|
||||
}
|
||||
|
||||
// validateAndWrapHandlers preforms validation and wrapping for each input handler.
|
||||
// It accepts an optional wrapper function to perform custom wrapping on handlers.
|
||||
func validateAndWrapHandlers(handlers []Handler) []http.Handler {
|
||||
wrappedHandlers := make([]http.Handler, len(handlers))
|
||||
for i, h := range handlers {
|
||||
wrappedHandlers[i] = validateAndWrapHandler(h)
|
||||
}
|
||||
|
||||
return wrappedHandlers
|
||||
}
|
||||
|
||||
// Macaron represents the top level web application.
|
||||
// Injector methods can be invoked to map services on a global level.
|
||||
type Macaron struct {
|
||||
handlers []http.Handler
|
||||
// handlers []http.Handler
|
||||
mws []Middleware
|
||||
|
||||
urlPrefix string // For suburl support.
|
||||
*Router
|
||||
@ -119,43 +107,50 @@ func SetURLParams(r *http.Request, vars map[string]string) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), paramsKey{}, vars))
|
||||
}
|
||||
|
||||
// UseMiddleware is a traditional approach to writing middleware in Go.
|
||||
// A middleware is a function that has a reference to the next handler in the chain
|
||||
// and returns the actual middleware handler, that may do its job and optionally
|
||||
// call next.
|
||||
// Due to how Macaron handles/injects requests and responses we patch the web.Context
|
||||
// to use the new ResponseWriter and http.Request here. The caller may only call
|
||||
// `next.ServeHTTP(rw, req)` to pass a modified response writer and/or a request to the
|
||||
// further middlewares in the chain.
|
||||
func (m *Macaron) UseMiddleware(middleware func(http.Handler) http.Handler) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
c := FromContext(req.Context())
|
||||
c.Req = req
|
||||
if mrw, ok := rw.(*responseWriter); ok {
|
||||
c.Resp = mrw
|
||||
} else {
|
||||
c.Resp = NewResponseWriter(req.Method, rw)
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
m.handlers = append(m.handlers, middleware(next))
|
||||
type Middleware = func(next http.Handler) http.Handler
|
||||
|
||||
// UseMiddleware registers the given Middleware
|
||||
func (m *Macaron) UseMiddleware(mw Middleware) {
|
||||
m.mws = append(m.mws, mw)
|
||||
}
|
||||
|
||||
// Use adds a middleware Handler to the stack,
|
||||
// and panics if the handler is not a callable func.
|
||||
// Middleware Handlers are invoked in the order that they are added.
|
||||
func (m *Macaron) Use(handler Handler) {
|
||||
h := validateAndWrapHandler(handler)
|
||||
m.handlers = append(m.handlers, h)
|
||||
// Use registers the provided Handler as a middleware.
|
||||
// The argument may be any supported handler or the Middleware type
|
||||
// Deprecated: use UseMiddleware instead
|
||||
func (m *Macaron) Use(h Handler) {
|
||||
m.mws = append(m.mws, mwFromHandler(h))
|
||||
}
|
||||
|
||||
func mwFromHandler(handler Handler) Middleware {
|
||||
if mw, ok := handler.(Middleware); ok {
|
||||
return mw
|
||||
}
|
||||
|
||||
h := wrapHandler(handler)
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mrw, ok := w.(*responseWriter)
|
||||
if !ok {
|
||||
mrw = NewResponseWriter(r.Method, w).(*responseWriter)
|
||||
}
|
||||
|
||||
h.ServeHTTP(mrw, r)
|
||||
if mrw.Written() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context().Value(macaronContextKey{}).(*Context)
|
||||
next.ServeHTTP(ctx.Resp, ctx.Req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Macaron) createContext(rw http.ResponseWriter, req *http.Request) *Context {
|
||||
c := &Context{
|
||||
handlers: m.handlers,
|
||||
index: 0,
|
||||
Router: m.Router,
|
||||
Resp: NewResponseWriter(req.Method, rw),
|
||||
logger: log.New("macaron.context"),
|
||||
mws: m.mws,
|
||||
Router: m.Router,
|
||||
Resp: NewResponseWriter(req.Method, rw),
|
||||
logger: log.New("macaron.context"),
|
||||
}
|
||||
|
||||
c.Req = req.WithContext(context.WithValue(req.Context(), macaronContextKey{}, c))
|
||||
|
@ -26,7 +26,7 @@ import (
|
||||
// Renderer is a Middleware that injects a template renderer into the macaron context, enabling ctx.HTML calls in the handlers.
|
||||
// If MACARON_ENV is set to "development" then templates will be recompiled on every request. For more performance, set the
|
||||
// MACARON_ENV environment variable to "production".
|
||||
func Renderer(dir, leftDelim, rightDelim string) func(http.Handler) http.Handler {
|
||||
func Renderer(dir, leftDelim, rightDelim string) Middleware {
|
||||
fs := os.DirFS(dir)
|
||||
t, err := compileTemplates(fs, leftDelim, rightDelim)
|
||||
if err != nil {
|
||||
|
@ -46,6 +46,16 @@ func NewResponseWriter(method string, rw http.ResponseWriter) ResponseWriter {
|
||||
return &responseWriter{method, rw, 0, 0, nil}
|
||||
}
|
||||
|
||||
// Rw returns a ResponseWriter. If the argument already satisfies the interface,
|
||||
// it is returned as is, otherwise it is wrapped using NewResponseWriter
|
||||
func Rw(rw http.ResponseWriter, req *http.Request) ResponseWriter {
|
||||
if mrw, ok := rw.(ResponseWriter); ok {
|
||||
return mrw
|
||||
}
|
||||
|
||||
return NewResponseWriter(req.Method, rw)
|
||||
}
|
||||
|
||||
type responseWriter struct {
|
||||
method string
|
||||
http.ResponseWriter
|
||||
|
@ -146,13 +146,12 @@ func (r *Router) Handle(method string, pattern string, handlers []Handler) {
|
||||
h = append(h, handlers...)
|
||||
handlers = h
|
||||
}
|
||||
httpHandlers := validateAndWrapHandlers(handlers)
|
||||
|
||||
r.handle(method, pattern, func(resp http.ResponseWriter, req *http.Request, params map[string]string) {
|
||||
c := r.m.createContext(resp, SetURLParams(req, params))
|
||||
c.handlers = make([]http.Handler, 0, len(r.m.handlers)+len(handlers))
|
||||
c.handlers = append(c.handlers, r.m.handlers...)
|
||||
c.handlers = append(c.handlers, httpHandlers...)
|
||||
for _, h := range handlers {
|
||||
c.mws = append(c.mws, mwFromHandler(h))
|
||||
}
|
||||
c.run()
|
||||
})
|
||||
}
|
||||
@ -194,12 +193,11 @@ func (r *Router) Any(pattern string, h ...Handler) { r.Handle("*", pattern, h) }
|
||||
// found. If it is not set, http.NotFound is used.
|
||||
// Be sure to set 404 response code in your handler.
|
||||
func (r *Router) NotFound(handlers ...Handler) {
|
||||
httpHandlers := validateAndWrapHandlers(handlers)
|
||||
r.notFound = func(rw http.ResponseWriter, req *http.Request) {
|
||||
c := r.m.createContext(rw, req)
|
||||
c.handlers = make([]http.Handler, 0, len(r.m.handlers)+len(handlers))
|
||||
c.handlers = append(c.handlers, r.m.handlers...)
|
||||
c.handlers = append(c.handlers, httpHandlers...)
|
||||
for _, h := range handlers {
|
||||
c.mws = append(c.mws, mwFromHandler(h))
|
||||
}
|
||||
c.run()
|
||||
}
|
||||
}
|
||||
|
73
pkg/web/webtest/middleware.go
Normal file
73
pkg/web/webtest/middleware.go
Normal file
@ -0,0 +1,73 @@
|
||||
package webtest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
Req *http.Request
|
||||
Rw http.ResponseWriter
|
||||
Next http.Handler
|
||||
}
|
||||
|
||||
// Middleware is a utility for testing middlewares
|
||||
type Middleware struct {
|
||||
// Before are run ahead of the returned context
|
||||
Before []web.Handler
|
||||
// After are part of the http.Handler chain
|
||||
After []web.Handler
|
||||
// The actual handler at the end of the chain
|
||||
Handler web.Handler
|
||||
}
|
||||
|
||||
// MiddlewareContext returns a *http.Request, http.ResponseWriter and http.Handler
|
||||
// exactly as if it was passed to a middleware
|
||||
func MiddlewareContext(test Middleware, req *http.Request) *Context {
|
||||
m := web.New()
|
||||
|
||||
// pkg/web requires the chain to write an HTTP response.
|
||||
// While this ensures a basic amount of correctness for real handler chains,
|
||||
// it is naturally incompatible with this package, as we terminate the chain early to pass its
|
||||
// state to the surrounding test.
|
||||
// By replacing the http.ResponseWriter and writing to the old one we make pkg/web happy.
|
||||
m.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
rw := web.Rw(httptest.NewRecorder(), r)
|
||||
next.ServeHTTP(rw, r)
|
||||
})
|
||||
})
|
||||
|
||||
for _, mw := range test.Before {
|
||||
m.Use(mw)
|
||||
}
|
||||
|
||||
ch := make(chan *Context)
|
||||
m.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ch <- &Context{
|
||||
Req: r,
|
||||
Rw: w,
|
||||
Next: next,
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
for _, mw := range test.After {
|
||||
m.Use(mw)
|
||||
}
|
||||
|
||||
// set the provided (or noop) handler to exactly the queried path
|
||||
handler := test.Handler
|
||||
if handler == nil {
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
m.Handle(req.Method, req.URL.RequestURI(), []web.Handler{handler})
|
||||
go m.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
return <-ch
|
||||
}
|
@ -37,7 +37,7 @@ func NewServer(t testing.TB, routeRegister routing.RouteRegister) *Server {
|
||||
c.Req = c.Req.WithContext(ctxkey.Set(c.Req.Context(), initCtx))
|
||||
})
|
||||
|
||||
m.Use(requestContextMiddleware())
|
||||
m.UseMiddleware(requestContextMiddleware())
|
||||
|
||||
routeRegister.Register(m.Router)
|
||||
testServer := httptest.NewServer(m)
|
||||
@ -126,24 +126,25 @@ func requestContextFromRequest(req *http.Request) *models.ReqContext {
|
||||
return val
|
||||
}
|
||||
|
||||
func requestContextMiddleware() web.Handler {
|
||||
return func(res http.ResponseWriter, req *http.Request) {
|
||||
c := ctxkey.Get(req.Context()).(*models.ReqContext)
|
||||
func requestContextMiddleware() web.Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c := ctxkey.Get(r.Context()).(*models.ReqContext)
|
||||
|
||||
ctx := requestContextFromRequest(req)
|
||||
if ctx == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
ctx := requestContextFromRequest(r)
|
||||
if ctx != nil {
|
||||
c.SignedInUser = ctx.SignedInUser
|
||||
c.UserToken = ctx.UserToken
|
||||
c.IsSignedIn = ctx.IsSignedIn
|
||||
c.IsRenderCall = ctx.IsRenderCall
|
||||
c.AllowAnonymous = ctx.AllowAnonymous
|
||||
c.SkipCache = ctx.SkipCache
|
||||
c.RequestNonce = ctx.RequestNonce
|
||||
c.PerfmonTimer = ctx.PerfmonTimer
|
||||
c.LookupTokenErr = ctx.LookupTokenErr
|
||||
}
|
||||
|
||||
c.SignedInUser = ctx.SignedInUser
|
||||
c.UserToken = ctx.UserToken
|
||||
c.IsSignedIn = ctx.IsSignedIn
|
||||
c.IsRenderCall = ctx.IsRenderCall
|
||||
c.AllowAnonymous = ctx.AllowAnonymous
|
||||
c.SkipCache = ctx.SkipCache
|
||||
c.RequestNonce = ctx.RequestNonce
|
||||
c.PerfmonTimer = ctx.PerfmonTimer
|
||||
c.LookupTokenErr = ctx.LookupTokenErr
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user