From 534ece064be2564ac956ee2a528de42c031aec29 Mon Sep 17 00:00:00 2001 From: sh0rez Date: Tue, 9 Aug 2022 14:58:50 +0200 Subject: [PATCH] pkg/web: closure-style middlewares (#51238) * pkg/web: closure-style middlewares Switches the middleware execution model from web.Handlers in a slice to web.Middleware. Middlewares are temporarily kept in a slice to preserve ordering, but prior to execution they are applied, forming a giant call-stack, giving granular control over the execution flow. * pkg/middleware: adapt to web.Middleware * pkg/middleware/recovery: use c.Req over req c.Req gets updated by future handlers, while req stays static. The current recovery implementation needs this newer information * pkg/web: correct middleware ordering * pkg/webtest: adapt middleware * pkg/web/hack: set w and r onto web.Context By adopting std middlewares, it may happen they invoke next(w,r) without putting their modified w,r into the web.Context, leading old-style handlers to operate on outdated fields. pkg/web now takes care of this * pkg/middleware: selectively use future context * pkg/web: accept closure-style on Use() * webtest: Middleware testing adds a utility function to web/webtest to obtain a http.ResponseWriter, http.Request and http.Handler the same as a middleware that runs would receive * *: cleanup * pkg/web: don't wrap Middleware from Router * pkg/web: require chain to write response * *: remove temp files * webtest: don't require chain write * *: cleanup --- pkg/api/http_server.go | 2 +- pkg/api/response/web_hack.go | 29 +++-- pkg/middleware/logger.go | 76 ++++++------- pkg/middleware/middleware_test.go | 3 + pkg/middleware/recovery.go | 104 +++++++++--------- pkg/middleware/recovery_test.go | 2 +- pkg/middleware/request_metrics.go | 87 ++++++++------- pkg/middleware/request_tracing.go | 53 ++++----- pkg/services/accesscontrol/middleware_test.go | 1 + pkg/tests/api/alerting/testing.go | 2 +- pkg/web/context.go | 37 +++---- pkg/web/macaron.go | 89 +++++++-------- pkg/web/render.go | 2 +- pkg/web/response_writer.go | 10 ++ pkg/web/router.go | 14 +-- pkg/web/webtest/middleware.go | 73 ++++++++++++ pkg/web/webtest/webtest.go | 37 ++++--- 17 files changed, 357 insertions(+), 264 deletions(-) create mode 100644 pkg/web/webtest/middleware.go diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index e1014e1149c..f2ac79c0edc 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -526,7 +526,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { m.UseMiddleware(middleware.Gziper()) } - m.Use(middleware.Recovery(hs.Cfg)) + m.UseMiddleware(middleware.Recovery(hs.Cfg)) m.UseMiddleware(hs.Csrf.Middleware()) hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build") diff --git a/pkg/api/response/web_hack.go b/pkg/api/response/web_hack.go index 0cb1ccdcc9a..81f3bfa4ac5 100644 --- a/pkg/api/response/web_hack.go +++ b/pkg/api/response/web_hack.go @@ -4,7 +4,6 @@ package response //NOTE: This file belongs into pkg/web, but due to cyclic imports that are hard to resolve at the current time, it temporarily lives here. import ( - "context" "fmt" "net/http" @@ -24,23 +23,25 @@ type ( func wrap_handler(h web.Handler) http.HandlerFunc { switch handle := h.(type) { + case http.HandlerFunc: + return handle case handlerStd: return handle case handlerStdCtx: return func(w http.ResponseWriter, r *http.Request) { - handle(w, r, web.FromContext(r.Context())) + handle(w, r, webCtx(w, r)) } case handlerStdReqCtx: return func(w http.ResponseWriter, r *http.Request) { - handle(w, r, getReqCtx(r.Context())) + handle(w, r, reqCtx(w, r)) } case handlerReqCtx: return func(w http.ResponseWriter, r *http.Request) { - handle(getReqCtx(r.Context())) + handle(reqCtx(w, r)) } case handlerReqCtxRes: return func(w http.ResponseWriter, r *http.Request) { - ctx := getReqCtx(r.Context()) + ctx := reqCtx(w, r) res := handle(ctx) if res != nil { res.WriteTo(ctx) @@ -48,15 +49,27 @@ func wrap_handler(h web.Handler) http.HandlerFunc { } case handlerCtx: return func(w http.ResponseWriter, r *http.Request) { - handle(web.FromContext(r.Context())) + handle(webCtx(w, r)) } } panic(fmt.Sprintf("unexpected handler type: %T", h)) } -func getReqCtx(ctx context.Context) *models.ReqContext { - reqCtx, ok := ctx.Value(ctxkey.Key{}).(*models.ReqContext) +func webCtx(w http.ResponseWriter, r *http.Request) *web.Context { + ctx := web.FromContext(r.Context()) + if ctx == nil { + panic("no *web.Context found") + } + + ctx.Req = r + ctx.Resp = web.Rw(w, r) + return ctx +} + +func reqCtx(w http.ResponseWriter, r *http.Request) *models.ReqContext { + wCtx := webCtx(w, r) + reqCtx, ok := wCtx.Req.Context().Value(ctxkey.Key{}).(*models.ReqContext) if !ok { panic("no *models.ReqContext found") } diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go index b5f75fe59ef..67bbf95ce09 100644 --- a/pkg/middleware/logger.go +++ b/pkg/middleware/logger.go @@ -27,50 +27,52 @@ import ( "github.com/grafana/grafana/pkg/web" ) -func Logger(cfg *setting.Cfg) web.Handler { - return func(res http.ResponseWriter, req *http.Request, c *web.Context) { - start := time.Now() +func Logger(cfg *setting.Cfg) web.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() - rw := res.(web.ResponseWriter) - c.Next() + rw := web.Rw(w, r) + next.ServeHTTP(rw, r) - timeTaken := time.Since(start) / time.Millisecond - duration := time.Since(start).String() - ctx := contexthandler.FromContext(c.Req.Context()) - if ctx != nil && ctx.PerfmonTimer != nil { - ctx.PerfmonTimer.Observe(float64(timeTaken)) - } - - status := rw.Status() - if status == 200 || status == 304 { - if !cfg.RouterLogging { - return - } - } - - if ctx != nil { - logParams := []interface{}{ - "method", req.Method, - "path", req.URL.Path, - "status", status, - "remote_addr", c.RemoteAddr(), - "time_ms", int64(timeTaken), - "duration", duration, - "size", rw.Size(), - "referer", SanitizeURL(ctx, req.Referer()), + timeTaken := time.Since(start) / time.Millisecond + duration := time.Since(start).String() + ctx := contexthandler.FromContext(r.Context()) + if ctx != nil && ctx.PerfmonTimer != nil { + ctx.PerfmonTimer.Observe(float64(timeTaken)) } - traceID := tracing.TraceIDFromContext(ctx.Req.Context(), false) - if traceID != "" { - logParams = append(logParams, "traceID", traceID) + status := rw.Status() + if status == 200 || status == 304 { + if !cfg.RouterLogging { + return + } } - if status >= 500 { - ctx.Logger.Error("Request Completed", logParams...) - } else { - ctx.Logger.Info("Request Completed", logParams...) + if ctx != nil { + logParams := []interface{}{ + "method", r.Method, + "path", r.URL.Path, + "status", status, + "remote_addr", ctx.RemoteAddr(), + "time_ms", int64(timeTaken), + "duration", duration, + "size", rw.Size(), + "referer", SanitizeURL(ctx, r.Referer()), + } + + traceID := tracing.TraceIDFromContext(ctx.Req.Context(), false) + if traceID != "" { + logParams = append(logParams, "traceID", traceID) + } + + if status >= 500 { + ctx.Logger.Error("Request Completed", logParams...) + } else { + ctx.Logger.Info("Request Completed", logParams...) + } } - } + }) } } diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 5d4a5c4fb74..77641c717cb 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -647,6 +647,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( sc.context = c if sc.handlerFunc != nil { sc.handlerFunc(sc.context) + if !c.Resp.Written() { + c.Resp.WriteHeader(http.StatusOK) + } } else { t.Log("Returning JSON OK") resp := make(map[string]interface{}) diff --git a/pkg/middleware/recovery.go b/pkg/middleware/recovery.go index fe5d3a22597..20b317b4e9d 100644 --- a/pkg/middleware/recovery.go +++ b/pkg/middleware/recovery.go @@ -102,69 +102,73 @@ func function(pc uintptr) []byte { // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. // While Martini is in development mode, Recovery will also output the panic as HTML. -func Recovery(cfg *setting.Cfg) web.Handler { - return func(c *web.Context) { - defer func() { - if r := recover(); r != nil { - var panicLogger log.Logger - panicLogger = log.New("recovery") - // try to get request logger - ctx := contexthandler.FromContext(c.Req.Context()) - if ctx != nil { - panicLogger = ctx.Logger - } +func Recovery(cfg *setting.Cfg) web.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + c := web.FromContext(req.Context()) - if err, ok := r.(error); ok { - // http.ErrAbortHandler is suppressed by default in the http package - // and used as a signal for aborting requests. Suppresses stacktrace - // since it doesn't add any important information. - if errors.Is(err, http.ErrAbortHandler) { - panicLogger.Error("Request error", "error", err) + defer func() { + if r := recover(); r != nil { + var panicLogger log.Logger + panicLogger = log.New("recovery") + // try to get request logger + ctx := contexthandler.FromContext(c.Req.Context()) + if ctx != nil { + panicLogger = ctx.Logger + } + + if err, ok := r.(error); ok { + // http.ErrAbortHandler is suppressed by default in the http package + // and used as a signal for aborting requests. Suppresses stacktrace + // since it doesn't add any important information. + if errors.Is(err, http.ErrAbortHandler) { + panicLogger.Error("Request error", "error", err) + return + } + } + + stack := stack(3) + panicLogger.Error("Request error", "error", r, "stack", string(stack)) + + // if response has already been written, skip. + if c.Resp.Written() { return } - } - stack := stack(3) - panicLogger.Error("Request error", "error", r, "stack", string(stack)) + data := struct { + Title string + AppTitle string + AppSubUrl string + Theme string + ErrorMsg string + }{"Server Error", "Grafana", cfg.AppSubURL, cfg.DefaultTheme, ""} - // if response has already been written, skip. - if c.Resp.Written() { - return - } + if setting.Env == setting.Dev { + if err, ok := r.(error); ok { + data.Title = err.Error() + } - data := struct { - Title string - AppTitle string - AppSubUrl string - Theme string - ErrorMsg string - }{"Server Error", "Grafana", cfg.AppSubURL, cfg.DefaultTheme, ""} - - if setting.Env == setting.Dev { - if err, ok := r.(error); ok { - data.Title = err.Error() + data.ErrorMsg = string(stack) } - data.ErrorMsg = string(stack) - } + if ctx != nil && ctx.IsApiRequest() { + resp := make(map[string]interface{}) + resp["message"] = "Internal Server Error - Check the Grafana server logs for the detailed error message." - if ctx != nil && ctx.IsApiRequest() { - resp := make(map[string]interface{}) - resp["message"] = "Internal Server Error - Check the Grafana server logs for the detailed error message." + if data.ErrorMsg != "" { + resp["error"] = fmt.Sprintf("%v - %v", data.Title, data.ErrorMsg) + } else { + resp["error"] = data.Title + } - if data.ErrorMsg != "" { - resp["error"] = fmt.Sprintf("%v - %v", data.Title, data.ErrorMsg) + ctx.JSON(500, resp) } else { - resp["error"] = data.Title + ctx.HTML(500, cfg.ErrTemplateName, data) } - - c.JSON(500, resp) - } else { - c.HTML(500, cfg.ErrTemplateName, data) } - } - }() + }() - c.Next() + next.ServeHTTP(w, req) + }) } } diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index 85bbef0246b..abf5b2510ff 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -59,7 +59,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { require.NoError(t, err) sc.m = web.New() - sc.m.Use(Recovery(cfg)) + sc.m.UseMiddleware(Recovery(cfg)) sc.m.Use(AddDefaultResponseHeaders(cfg)) sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]")) diff --git a/pkg/middleware/request_metrics.go b/pkg/middleware/request_metrics.go index 4624f8f6ab2..9706968fd2e 100644 --- a/pkg/middleware/request_metrics.go +++ b/pkg/middleware/request_metrics.go @@ -46,57 +46,60 @@ func init() { } // RequestMetrics is a middleware handler that instruments the request. -func RequestMetrics(features featuremgmt.FeatureToggles) web.Handler { +func RequestMetrics(features featuremgmt.FeatureToggles) web.Middleware { log := log.New("middleware.request-metrics") - return func(res http.ResponseWriter, req *http.Request, c *web.Context) { - rw := res.(web.ResponseWriter) - now := time.Now() - httpRequestsInFlight.Inc() - defer httpRequestsInFlight.Dec() - c.Next() + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rw := web.Rw(w, r) + now := time.Now() + httpRequestsInFlight.Inc() + defer httpRequestsInFlight.Dec() + next.ServeHTTP(w, r) - status := rw.Status() - code := sanitizeCode(status) + status := rw.Status() + code := sanitizeCode(status) - handler := "unknown" - if routeOperation, exists := routeOperationName(c.Req); exists { - handler = routeOperation - } else { - // if grafana does not recognize the handler and returns 404 we should register it as `notfound` - if status == http.StatusNotFound { - handler = "notfound" + handler := "unknown" + // TODO: do not depend on web.Context from the future + if routeOperation, exists := routeOperationName(web.FromContext(r.Context()).Req); exists { + handler = routeOperation } else { - // log requests where we could not identify handler so we can register them. - if features.IsEnabled(featuremgmt.FlagLogRequestsInstrumentedAsUnknown) { - log.Warn("request instrumented as unknown", "path", c.Req.URL.Path, "status_code", status) + // if grafana does not recognize the handler and returns 404 we should register it as `notfound` + if status == http.StatusNotFound { + handler = "notfound" + } else { + // log requests where we could not identify handler so we can register them. + if features.IsEnabled(featuremgmt.FlagLogRequestsInstrumentedAsUnknown) { + log.Warn("request instrumented as unknown", "path", r.URL.Path, "status_code", status) + } } } - } - // avoiding the sanitize functions for in the new instrumentation - // since they dont make much sense. We should remove them later. - histogram := httpRequestDurationHistogram. - WithLabelValues(handler, code, req.Method) - if traceID := tracing.TraceIDFromContext(c.Req.Context(), true); traceID != "" { - // Need to type-convert the Observer to an - // ExemplarObserver. This will always work for a - // HistogramVec. - histogram.(prometheus.ExemplarObserver).ObserveWithExemplar( - time.Since(now).Seconds(), prometheus.Labels{"traceID": traceID}, - ) - return - } - histogram.Observe(time.Since(now).Seconds()) + // avoiding the sanitize functions for in the new instrumentation + // since they dont make much sense. We should remove them later. + histogram := httpRequestDurationHistogram. + WithLabelValues(handler, code, r.Method) + if traceID := tracing.TraceIDFromContext(r.Context(), true); traceID != "" { + // Need to type-convert the Observer to an + // ExemplarObserver. This will always work for a + // HistogramVec. + histogram.(prometheus.ExemplarObserver).ObserveWithExemplar( + time.Since(now).Seconds(), prometheus.Labels{"traceID": traceID}, + ) + return + } + histogram.Observe(time.Since(now).Seconds()) - switch { - case strings.HasPrefix(req.RequestURI, "/api/datasources/proxy"): - countProxyRequests(status) - case strings.HasPrefix(req.RequestURI, "/api/"): - countApiRequests(status) - default: - countPageRequests(status) - } + switch { + case strings.HasPrefix(r.RequestURI, "/api/datasources/proxy"): + countProxyRequests(status) + case strings.HasPrefix(r.RequestURI, "/api/"): + countApiRequests(status) + default: + countPageRequests(status) + } + }) } } diff --git a/pkg/middleware/request_tracing.go b/pkg/middleware/request_tracing.go index dcfce62fabe..c03e10a2b1d 100644 --- a/pkg/middleware/request_tracing.go +++ b/pkg/middleware/request_tracing.go @@ -61,37 +61,38 @@ func routeOperationName(req *http.Request) (string, bool) { return "", false } -func RequestTracing(tracer tracing.Tracer) web.Handler { - return func(res http.ResponseWriter, req *http.Request, c *web.Context) { - if strings.HasPrefix(c.Req.URL.Path, "/public/") || - c.Req.URL.Path == "/robots.txt" || - c.Req.URL.Path == "/favicon.ico" { - c.Next() - return - } +func RequestTracing(tracer tracing.Tracer) web.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if strings.HasPrefix(req.URL.Path, "/public/") || req.URL.Path == "/robots.txt" || req.URL.Path == "/favicon.ico" { + next.ServeHTTP(w, req) + return + } - rw := res.(web.ResponseWriter) + rw := web.Rw(w, req) - wireContext := otel.GetTextMapPropagator().Extract(req.Context(), propagation.HeaderCarrier(req.Header)) - ctx, span := tracer.Start(req.Context(), fmt.Sprintf("HTTP %s %s", req.Method, req.URL.Path), trace.WithLinks(trace.LinkFromContext(wireContext))) + wireContext := otel.GetTextMapPropagator().Extract(req.Context(), propagation.HeaderCarrier(req.Header)) + ctx, span := tracer.Start(req.Context(), fmt.Sprintf("HTTP %s %s", req.Method, req.URL.Path), trace.WithLinks(trace.LinkFromContext(wireContext))) - c.Req = req.WithContext(ctx) - c.Next() + req = req.WithContext(ctx) + next.ServeHTTP(w, req) - // Only call span.Finish when a route operation name have been set, - // meaning that not set the span would not be reported. - if routeOperation, exists := routeOperationName(c.Req); exists { - defer span.End() - span.SetName(fmt.Sprintf("HTTP %s %s", req.Method, routeOperation)) - } + // Only call span.Finish when a route operation name have been set, + // meaning that not set the span would not be reported. + // TODO: do not depend on web.Context from the future + if routeOperation, exists := routeOperationName(web.FromContext(req.Context()).Req); exists { + defer span.End() + span.SetName(fmt.Sprintf("HTTP %s %s", req.Method, routeOperation)) + } - status := rw.Status() + status := rw.Status() - span.SetAttributes("http.status_code", status, attribute.Int("http.status_code", status)) - span.SetAttributes("http.url", req.RequestURI, attribute.String("http.url", req.RequestURI)) - span.SetAttributes("http.method", req.Method, attribute.String("http.method", req.Method)) - if status >= 400 { - span.SetStatus(codes.Error, fmt.Sprintf("error with HTTP status code %s", strconv.Itoa(status))) - } + span.SetAttributes("http.status_code", status, attribute.Int("http.status_code", status)) + span.SetAttributes("http.url", req.RequestURI, attribute.String("http.url", req.RequestURI)) + span.SetAttributes("http.method", req.Method, attribute.String("http.method", req.Method)) + if status >= 400 { + span.SetStatus(codes.Error, fmt.Sprintf("error with HTTP status code %s", strconv.Itoa(status))) + } + }) } } diff --git a/pkg/services/accesscontrol/middleware_test.go b/pkg/services/accesscontrol/middleware_test.go index 0f10974bc5f..013e3d98876 100644 --- a/pkg/services/accesscontrol/middleware_test.go +++ b/pkg/services/accesscontrol/middleware_test.go @@ -67,6 +67,7 @@ func TestMiddleware(t *testing.T) { endpointCalled := false server.Get("/", func(c *models.ReqContext) { endpointCalled = true + c.Resp.WriteHeader(http.StatusOK) }) request, err := http.NewRequest(http.MethodGet, "/", nil) diff --git a/pkg/tests/api/alerting/testing.go b/pkg/tests/api/alerting/testing.go index a9e1883736b..3fc58e6b854 100644 --- a/pkg/tests/api/alerting/testing.go +++ b/pkg/tests/api/alerting/testing.go @@ -50,10 +50,10 @@ func getRequest(t *testing.T, url string, expStatusCode int) *http.Response { t.Helper() // nolint:gosec resp, err := http.Get(url) + require.NoError(t, err) t.Cleanup(func() { require.NoError(t, resp.Body.Close()) }) - require.NoError(t, err) if expStatusCode != resp.StatusCode { b, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) diff --git a/pkg/web/context.go b/pkg/web/context.go index 06095e70aa6..59e47b0cd80 100644 --- a/pkg/web/context.go +++ b/pkg/web/context.go @@ -29,8 +29,7 @@ import ( // Context represents the runtime context of current request of Macaron instance. // It is the integration of most frequently used middlewares and helper methods. type Context struct { - handlers []http.Handler - index int + mws []Middleware *Router Req *http.Request @@ -39,30 +38,20 @@ type Context struct { logger log.Logger } -func (ctx *Context) handler() http.Handler { - if ctx.index < len(ctx.handlers) { - return ctx.handlers[ctx.index] - } - if ctx.index == len(ctx.handlers) { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - } - panic("invalid index for context handler") -} - -// Next runs the next handler in the context chain -func (ctx *Context) Next() { - ctx.index++ - ctx.run() -} - func (ctx *Context) run() { - for ctx.index <= len(ctx.handlers) { - ctx.handler().ServeHTTP(ctx.Resp, ctx.Req) + h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + for i := len(ctx.mws) - 1; i >= 0; i-- { + h = ctx.mws[i](h) + } - ctx.index++ - if ctx.Resp.Written() { - return - } + rw := ctx.Resp + h.ServeHTTP(ctx.Resp, ctx.Req) + + // Prevent the handler chain from not writing anything. + // This indicates nearly always that a middleware is misbehaving and not calling its next.ServeHTTP(). + // In rare cases where a blank http.StatusOK without any body is wished, explicitly state that using w.WriteStatus(http.StatusOK) + if !rw.Written() { + panic("chain did not write HTTP response") } } diff --git a/pkg/web/macaron.go b/pkg/web/macaron.go index edec73da060..76fa0c1dd5d 100644 --- a/pkg/web/macaron.go +++ b/pkg/web/macaron.go @@ -53,28 +53,16 @@ type Handler interface{} //go:linkname hack_wrap github.com/grafana/grafana/pkg/api/response.wrap_handler func hack_wrap(Handler) http.HandlerFunc -// validateAndWrapHandler makes sure a handler is a callable function, it panics if not. -// When the handler is also potential to be any built-in inject.FastInvoker, -// it wraps the handler automatically to have some performance gain. -func validateAndWrapHandler(h Handler) http.Handler { +// wrapHandler turns any supported handler type into a http.Handler by wrapping it accordingly +func wrapHandler(h Handler) http.Handler { return hack_wrap(h) } -// validateAndWrapHandlers preforms validation and wrapping for each input handler. -// It accepts an optional wrapper function to perform custom wrapping on handlers. -func validateAndWrapHandlers(handlers []Handler) []http.Handler { - wrappedHandlers := make([]http.Handler, len(handlers)) - for i, h := range handlers { - wrappedHandlers[i] = validateAndWrapHandler(h) - } - - return wrappedHandlers -} - // Macaron represents the top level web application. // Injector methods can be invoked to map services on a global level. type Macaron struct { - handlers []http.Handler + // handlers []http.Handler + mws []Middleware urlPrefix string // For suburl support. *Router @@ -119,43 +107,50 @@ func SetURLParams(r *http.Request, vars map[string]string) *http.Request { return r.WithContext(context.WithValue(r.Context(), paramsKey{}, vars)) } -// UseMiddleware is a traditional approach to writing middleware in Go. -// A middleware is a function that has a reference to the next handler in the chain -// and returns the actual middleware handler, that may do its job and optionally -// call next. -// Due to how Macaron handles/injects requests and responses we patch the web.Context -// to use the new ResponseWriter and http.Request here. The caller may only call -// `next.ServeHTTP(rw, req)` to pass a modified response writer and/or a request to the -// further middlewares in the chain. -func (m *Macaron) UseMiddleware(middleware func(http.Handler) http.Handler) { - next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - c := FromContext(req.Context()) - c.Req = req - if mrw, ok := rw.(*responseWriter); ok { - c.Resp = mrw - } else { - c.Resp = NewResponseWriter(req.Method, rw) - } - c.Next() - }) - m.handlers = append(m.handlers, middleware(next)) +type Middleware = func(next http.Handler) http.Handler + +// UseMiddleware registers the given Middleware +func (m *Macaron) UseMiddleware(mw Middleware) { + m.mws = append(m.mws, mw) } -// Use adds a middleware Handler to the stack, -// and panics if the handler is not a callable func. -// Middleware Handlers are invoked in the order that they are added. -func (m *Macaron) Use(handler Handler) { - h := validateAndWrapHandler(handler) - m.handlers = append(m.handlers, h) +// Use registers the provided Handler as a middleware. +// The argument may be any supported handler or the Middleware type +// Deprecated: use UseMiddleware instead +func (m *Macaron) Use(h Handler) { + m.mws = append(m.mws, mwFromHandler(h)) +} + +func mwFromHandler(handler Handler) Middleware { + if mw, ok := handler.(Middleware); ok { + return mw + } + + h := wrapHandler(handler) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mrw, ok := w.(*responseWriter) + if !ok { + mrw = NewResponseWriter(r.Method, w).(*responseWriter) + } + + h.ServeHTTP(mrw, r) + if mrw.Written() { + return + } + + ctx := r.Context().Value(macaronContextKey{}).(*Context) + next.ServeHTTP(ctx.Resp, ctx.Req) + }) + } } func (m *Macaron) createContext(rw http.ResponseWriter, req *http.Request) *Context { c := &Context{ - handlers: m.handlers, - index: 0, - Router: m.Router, - Resp: NewResponseWriter(req.Method, rw), - logger: log.New("macaron.context"), + mws: m.mws, + Router: m.Router, + Resp: NewResponseWriter(req.Method, rw), + logger: log.New("macaron.context"), } c.Req = req.WithContext(context.WithValue(req.Context(), macaronContextKey{}, c)) diff --git a/pkg/web/render.go b/pkg/web/render.go index 90bd86e74c9..75d8f1e17b6 100644 --- a/pkg/web/render.go +++ b/pkg/web/render.go @@ -26,7 +26,7 @@ import ( // Renderer is a Middleware that injects a template renderer into the macaron context, enabling ctx.HTML calls in the handlers. // If MACARON_ENV is set to "development" then templates will be recompiled on every request. For more performance, set the // MACARON_ENV environment variable to "production". -func Renderer(dir, leftDelim, rightDelim string) func(http.Handler) http.Handler { +func Renderer(dir, leftDelim, rightDelim string) Middleware { fs := os.DirFS(dir) t, err := compileTemplates(fs, leftDelim, rightDelim) if err != nil { diff --git a/pkg/web/response_writer.go b/pkg/web/response_writer.go index 8a11b65fbe1..62d90642419 100644 --- a/pkg/web/response_writer.go +++ b/pkg/web/response_writer.go @@ -46,6 +46,16 @@ func NewResponseWriter(method string, rw http.ResponseWriter) ResponseWriter { return &responseWriter{method, rw, 0, 0, nil} } +// Rw returns a ResponseWriter. If the argument already satisfies the interface, +// it is returned as is, otherwise it is wrapped using NewResponseWriter +func Rw(rw http.ResponseWriter, req *http.Request) ResponseWriter { + if mrw, ok := rw.(ResponseWriter); ok { + return mrw + } + + return NewResponseWriter(req.Method, rw) +} + type responseWriter struct { method string http.ResponseWriter diff --git a/pkg/web/router.go b/pkg/web/router.go index 3f32d428c29..32a16fd2031 100644 --- a/pkg/web/router.go +++ b/pkg/web/router.go @@ -146,13 +146,12 @@ func (r *Router) Handle(method string, pattern string, handlers []Handler) { h = append(h, handlers...) handlers = h } - httpHandlers := validateAndWrapHandlers(handlers) r.handle(method, pattern, func(resp http.ResponseWriter, req *http.Request, params map[string]string) { c := r.m.createContext(resp, SetURLParams(req, params)) - c.handlers = make([]http.Handler, 0, len(r.m.handlers)+len(handlers)) - c.handlers = append(c.handlers, r.m.handlers...) - c.handlers = append(c.handlers, httpHandlers...) + for _, h := range handlers { + c.mws = append(c.mws, mwFromHandler(h)) + } c.run() }) } @@ -194,12 +193,11 @@ func (r *Router) Any(pattern string, h ...Handler) { r.Handle("*", pattern, h) } // found. If it is not set, http.NotFound is used. // Be sure to set 404 response code in your handler. func (r *Router) NotFound(handlers ...Handler) { - httpHandlers := validateAndWrapHandlers(handlers) r.notFound = func(rw http.ResponseWriter, req *http.Request) { c := r.m.createContext(rw, req) - c.handlers = make([]http.Handler, 0, len(r.m.handlers)+len(handlers)) - c.handlers = append(c.handlers, r.m.handlers...) - c.handlers = append(c.handlers, httpHandlers...) + for _, h := range handlers { + c.mws = append(c.mws, mwFromHandler(h)) + } c.run() } } diff --git a/pkg/web/webtest/middleware.go b/pkg/web/webtest/middleware.go new file mode 100644 index 00000000000..8c67581259d --- /dev/null +++ b/pkg/web/webtest/middleware.go @@ -0,0 +1,73 @@ +package webtest + +import ( + "net/http" + "net/http/httptest" + + "github.com/grafana/grafana/pkg/web" +) + +type Context struct { + Req *http.Request + Rw http.ResponseWriter + Next http.Handler +} + +// Middleware is a utility for testing middlewares +type Middleware struct { + // Before are run ahead of the returned context + Before []web.Handler + // After are part of the http.Handler chain + After []web.Handler + // The actual handler at the end of the chain + Handler web.Handler +} + +// MiddlewareContext returns a *http.Request, http.ResponseWriter and http.Handler +// exactly as if it was passed to a middleware +func MiddlewareContext(test Middleware, req *http.Request) *Context { + m := web.New() + + // pkg/web requires the chain to write an HTTP response. + // While this ensures a basic amount of correctness for real handler chains, + // it is naturally incompatible with this package, as we terminate the chain early to pass its + // state to the surrounding test. + // By replacing the http.ResponseWriter and writing to the old one we make pkg/web happy. + m.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + + rw := web.Rw(httptest.NewRecorder(), r) + next.ServeHTTP(rw, r) + }) + }) + + for _, mw := range test.Before { + m.Use(mw) + } + + ch := make(chan *Context) + m.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ch <- &Context{ + Req: r, + Rw: w, + Next: next, + } + }) + }) + + for _, mw := range test.After { + m.Use(mw) + } + + // set the provided (or noop) handler to exactly the queried path + handler := test.Handler + if handler == nil { + handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + } + m.Handle(req.Method, req.URL.RequestURI(), []web.Handler{handler}) + go m.ServeHTTP(httptest.NewRecorder(), req) + + return <-ch +} diff --git a/pkg/web/webtest/webtest.go b/pkg/web/webtest/webtest.go index 849a7128d91..c0cd02767e7 100644 --- a/pkg/web/webtest/webtest.go +++ b/pkg/web/webtest/webtest.go @@ -37,7 +37,7 @@ func NewServer(t testing.TB, routeRegister routing.RouteRegister) *Server { c.Req = c.Req.WithContext(ctxkey.Set(c.Req.Context(), initCtx)) }) - m.Use(requestContextMiddleware()) + m.UseMiddleware(requestContextMiddleware()) routeRegister.Register(m.Router) testServer := httptest.NewServer(m) @@ -126,24 +126,25 @@ func requestContextFromRequest(req *http.Request) *models.ReqContext { return val } -func requestContextMiddleware() web.Handler { - return func(res http.ResponseWriter, req *http.Request) { - c := ctxkey.Get(req.Context()).(*models.ReqContext) +func requestContextMiddleware() web.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := ctxkey.Get(r.Context()).(*models.ReqContext) - ctx := requestContextFromRequest(req) - if ctx == nil { - c.Next() - return - } + ctx := requestContextFromRequest(r) + if ctx != nil { + c.SignedInUser = ctx.SignedInUser + c.UserToken = ctx.UserToken + c.IsSignedIn = ctx.IsSignedIn + c.IsRenderCall = ctx.IsRenderCall + c.AllowAnonymous = ctx.AllowAnonymous + c.SkipCache = ctx.SkipCache + c.RequestNonce = ctx.RequestNonce + c.PerfmonTimer = ctx.PerfmonTimer + c.LookupTokenErr = ctx.LookupTokenErr + } - c.SignedInUser = ctx.SignedInUser - c.UserToken = ctx.UserToken - c.IsSignedIn = ctx.IsSignedIn - c.IsRenderCall = ctx.IsRenderCall - c.AllowAnonymous = ctx.AllowAnonymous - c.SkipCache = ctx.SkipCache - c.RequestNonce = ctx.RequestNonce - c.PerfmonTimer = ctx.PerfmonTimer - c.LookupTokenErr = ctx.LookupTokenErr + next.ServeHTTP(w, r) + }) } }