pkg/web: closure-style middlewares (#51238)

* pkg/web: closure-style middlewares

Switches the middleware execution model from web.Handlers in a slice to
web.Middleware.
Middlewares are temporarily kept in a slice to preserve ordering, but
prior to execution they are applied, forming a giant call-stack, giving
granular control over the execution flow.

* pkg/middleware: adapt to web.Middleware

* pkg/middleware/recovery: use c.Req over req

c.Req gets updated by future handlers, while req stays static.

The current recovery implementation needs this newer information

* pkg/web: correct middleware ordering

* pkg/webtest: adapt middleware

* pkg/web/hack: set w and r onto web.Context

By adopting std middlewares, it may happen they invoke next(w,r) without
putting their modified w,r into the web.Context, leading old-style
handlers to operate on outdated fields.

pkg/web now takes care of this

* pkg/middleware: selectively use future context

* pkg/web: accept closure-style on Use()

* webtest: Middleware testing

adds a utility function to web/webtest to obtain a http.ResponseWriter,
http.Request and http.Handler the same as a middleware that runs would receive

* *: cleanup

* pkg/web: don't wrap Middleware from Router

* pkg/web: require chain to write response

* *: remove temp files

* webtest: don't require chain write

* *: cleanup
This commit is contained in:
sh0rez 2022-08-09 14:58:50 +02:00 committed by GitHub
parent 3893c46976
commit 534ece064b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 357 additions and 264 deletions

View File

@ -526,7 +526,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
m.UseMiddleware(middleware.Gziper())
}
m.Use(middleware.Recovery(hs.Cfg))
m.UseMiddleware(middleware.Recovery(hs.Cfg))
m.UseMiddleware(hs.Csrf.Middleware())
hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build")

View File

@ -4,7 +4,6 @@ package response
//NOTE: This file belongs into pkg/web, but due to cyclic imports that are hard to resolve at the current time, it temporarily lives here.
import (
"context"
"fmt"
"net/http"
@ -24,23 +23,25 @@ type (
func wrap_handler(h web.Handler) http.HandlerFunc {
switch handle := h.(type) {
case http.HandlerFunc:
return handle
case handlerStd:
return handle
case handlerStdCtx:
return func(w http.ResponseWriter, r *http.Request) {
handle(w, r, web.FromContext(r.Context()))
handle(w, r, webCtx(w, r))
}
case handlerStdReqCtx:
return func(w http.ResponseWriter, r *http.Request) {
handle(w, r, getReqCtx(r.Context()))
handle(w, r, reqCtx(w, r))
}
case handlerReqCtx:
return func(w http.ResponseWriter, r *http.Request) {
handle(getReqCtx(r.Context()))
handle(reqCtx(w, r))
}
case handlerReqCtxRes:
return func(w http.ResponseWriter, r *http.Request) {
ctx := getReqCtx(r.Context())
ctx := reqCtx(w, r)
res := handle(ctx)
if res != nil {
res.WriteTo(ctx)
@ -48,15 +49,27 @@ func wrap_handler(h web.Handler) http.HandlerFunc {
}
case handlerCtx:
return func(w http.ResponseWriter, r *http.Request) {
handle(web.FromContext(r.Context()))
handle(webCtx(w, r))
}
}
panic(fmt.Sprintf("unexpected handler type: %T", h))
}
func getReqCtx(ctx context.Context) *models.ReqContext {
reqCtx, ok := ctx.Value(ctxkey.Key{}).(*models.ReqContext)
func webCtx(w http.ResponseWriter, r *http.Request) *web.Context {
ctx := web.FromContext(r.Context())
if ctx == nil {
panic("no *web.Context found")
}
ctx.Req = r
ctx.Resp = web.Rw(w, r)
return ctx
}
func reqCtx(w http.ResponseWriter, r *http.Request) *models.ReqContext {
wCtx := webCtx(w, r)
reqCtx, ok := wCtx.Req.Context().Value(ctxkey.Key{}).(*models.ReqContext)
if !ok {
panic("no *models.ReqContext found")
}

View File

@ -27,50 +27,52 @@ import (
"github.com/grafana/grafana/pkg/web"
)
func Logger(cfg *setting.Cfg) web.Handler {
return func(res http.ResponseWriter, req *http.Request, c *web.Context) {
start := time.Now()
func Logger(cfg *setting.Cfg) web.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
rw := res.(web.ResponseWriter)
c.Next()
rw := web.Rw(w, r)
next.ServeHTTP(rw, r)
timeTaken := time.Since(start) / time.Millisecond
duration := time.Since(start).String()
ctx := contexthandler.FromContext(c.Req.Context())
if ctx != nil && ctx.PerfmonTimer != nil {
ctx.PerfmonTimer.Observe(float64(timeTaken))
}
status := rw.Status()
if status == 200 || status == 304 {
if !cfg.RouterLogging {
return
}
}
if ctx != nil {
logParams := []interface{}{
"method", req.Method,
"path", req.URL.Path,
"status", status,
"remote_addr", c.RemoteAddr(),
"time_ms", int64(timeTaken),
"duration", duration,
"size", rw.Size(),
"referer", SanitizeURL(ctx, req.Referer()),
timeTaken := time.Since(start) / time.Millisecond
duration := time.Since(start).String()
ctx := contexthandler.FromContext(r.Context())
if ctx != nil && ctx.PerfmonTimer != nil {
ctx.PerfmonTimer.Observe(float64(timeTaken))
}
traceID := tracing.TraceIDFromContext(ctx.Req.Context(), false)
if traceID != "" {
logParams = append(logParams, "traceID", traceID)
status := rw.Status()
if status == 200 || status == 304 {
if !cfg.RouterLogging {
return
}
}
if status >= 500 {
ctx.Logger.Error("Request Completed", logParams...)
} else {
ctx.Logger.Info("Request Completed", logParams...)
if ctx != nil {
logParams := []interface{}{
"method", r.Method,
"path", r.URL.Path,
"status", status,
"remote_addr", ctx.RemoteAddr(),
"time_ms", int64(timeTaken),
"duration", duration,
"size", rw.Size(),
"referer", SanitizeURL(ctx, r.Referer()),
}
traceID := tracing.TraceIDFromContext(ctx.Req.Context(), false)
if traceID != "" {
logParams = append(logParams, "traceID", traceID)
}
if status >= 500 {
ctx.Logger.Error("Request Completed", logParams...)
} else {
ctx.Logger.Info("Request Completed", logParams...)
}
}
}
})
}
}

View File

@ -647,6 +647,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
sc.context = c
if sc.handlerFunc != nil {
sc.handlerFunc(sc.context)
if !c.Resp.Written() {
c.Resp.WriteHeader(http.StatusOK)
}
} else {
t.Log("Returning JSON OK")
resp := make(map[string]interface{})

View File

@ -102,69 +102,73 @@ func function(pc uintptr) []byte {
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// While Martini is in development mode, Recovery will also output the panic as HTML.
func Recovery(cfg *setting.Cfg) web.Handler {
return func(c *web.Context) {
defer func() {
if r := recover(); r != nil {
var panicLogger log.Logger
panicLogger = log.New("recovery")
// try to get request logger
ctx := contexthandler.FromContext(c.Req.Context())
if ctx != nil {
panicLogger = ctx.Logger
}
func Recovery(cfg *setting.Cfg) web.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
c := web.FromContext(req.Context())
if err, ok := r.(error); ok {
// http.ErrAbortHandler is suppressed by default in the http package
// and used as a signal for aborting requests. Suppresses stacktrace
// since it doesn't add any important information.
if errors.Is(err, http.ErrAbortHandler) {
panicLogger.Error("Request error", "error", err)
defer func() {
if r := recover(); r != nil {
var panicLogger log.Logger
panicLogger = log.New("recovery")
// try to get request logger
ctx := contexthandler.FromContext(c.Req.Context())
if ctx != nil {
panicLogger = ctx.Logger
}
if err, ok := r.(error); ok {
// http.ErrAbortHandler is suppressed by default in the http package
// and used as a signal for aborting requests. Suppresses stacktrace
// since it doesn't add any important information.
if errors.Is(err, http.ErrAbortHandler) {
panicLogger.Error("Request error", "error", err)
return
}
}
stack := stack(3)
panicLogger.Error("Request error", "error", r, "stack", string(stack))
// if response has already been written, skip.
if c.Resp.Written() {
return
}
}
stack := stack(3)
panicLogger.Error("Request error", "error", r, "stack", string(stack))
data := struct {
Title string
AppTitle string
AppSubUrl string
Theme string
ErrorMsg string
}{"Server Error", "Grafana", cfg.AppSubURL, cfg.DefaultTheme, ""}
// if response has already been written, skip.
if c.Resp.Written() {
return
}
if setting.Env == setting.Dev {
if err, ok := r.(error); ok {
data.Title = err.Error()
}
data := struct {
Title string
AppTitle string
AppSubUrl string
Theme string
ErrorMsg string
}{"Server Error", "Grafana", cfg.AppSubURL, cfg.DefaultTheme, ""}
if setting.Env == setting.Dev {
if err, ok := r.(error); ok {
data.Title = err.Error()
data.ErrorMsg = string(stack)
}
data.ErrorMsg = string(stack)
}
if ctx != nil && ctx.IsApiRequest() {
resp := make(map[string]interface{})
resp["message"] = "Internal Server Error - Check the Grafana server logs for the detailed error message."
if ctx != nil && ctx.IsApiRequest() {
resp := make(map[string]interface{})
resp["message"] = "Internal Server Error - Check the Grafana server logs for the detailed error message."
if data.ErrorMsg != "" {
resp["error"] = fmt.Sprintf("%v - %v", data.Title, data.ErrorMsg)
} else {
resp["error"] = data.Title
}
if data.ErrorMsg != "" {
resp["error"] = fmt.Sprintf("%v - %v", data.Title, data.ErrorMsg)
ctx.JSON(500, resp)
} else {
resp["error"] = data.Title
ctx.HTML(500, cfg.ErrTemplateName, data)
}
c.JSON(500, resp)
} else {
c.HTML(500, cfg.ErrTemplateName, data)
}
}
}()
}()
c.Next()
next.ServeHTTP(w, req)
})
}
}

View File

@ -59,7 +59,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
require.NoError(t, err)
sc.m = web.New()
sc.m.Use(Recovery(cfg))
sc.m.UseMiddleware(Recovery(cfg))
sc.m.Use(AddDefaultResponseHeaders(cfg))
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))

View File

@ -46,57 +46,60 @@ func init() {
}
// RequestMetrics is a middleware handler that instruments the request.
func RequestMetrics(features featuremgmt.FeatureToggles) web.Handler {
func RequestMetrics(features featuremgmt.FeatureToggles) web.Middleware {
log := log.New("middleware.request-metrics")
return func(res http.ResponseWriter, req *http.Request, c *web.Context) {
rw := res.(web.ResponseWriter)
now := time.Now()
httpRequestsInFlight.Inc()
defer httpRequestsInFlight.Dec()
c.Next()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := web.Rw(w, r)
now := time.Now()
httpRequestsInFlight.Inc()
defer httpRequestsInFlight.Dec()
next.ServeHTTP(w, r)
status := rw.Status()
code := sanitizeCode(status)
status := rw.Status()
code := sanitizeCode(status)
handler := "unknown"
if routeOperation, exists := routeOperationName(c.Req); exists {
handler = routeOperation
} else {
// if grafana does not recognize the handler and returns 404 we should register it as `notfound`
if status == http.StatusNotFound {
handler = "notfound"
handler := "unknown"
// TODO: do not depend on web.Context from the future
if routeOperation, exists := routeOperationName(web.FromContext(r.Context()).Req); exists {
handler = routeOperation
} else {
// log requests where we could not identify handler so we can register them.
if features.IsEnabled(featuremgmt.FlagLogRequestsInstrumentedAsUnknown) {
log.Warn("request instrumented as unknown", "path", c.Req.URL.Path, "status_code", status)
// if grafana does not recognize the handler and returns 404 we should register it as `notfound`
if status == http.StatusNotFound {
handler = "notfound"
} else {
// log requests where we could not identify handler so we can register them.
if features.IsEnabled(featuremgmt.FlagLogRequestsInstrumentedAsUnknown) {
log.Warn("request instrumented as unknown", "path", r.URL.Path, "status_code", status)
}
}
}
}
// avoiding the sanitize functions for in the new instrumentation
// since they dont make much sense. We should remove them later.
histogram := httpRequestDurationHistogram.
WithLabelValues(handler, code, req.Method)
if traceID := tracing.TraceIDFromContext(c.Req.Context(), true); traceID != "" {
// Need to type-convert the Observer to an
// ExemplarObserver. This will always work for a
// HistogramVec.
histogram.(prometheus.ExemplarObserver).ObserveWithExemplar(
time.Since(now).Seconds(), prometheus.Labels{"traceID": traceID},
)
return
}
histogram.Observe(time.Since(now).Seconds())
// avoiding the sanitize functions for in the new instrumentation
// since they dont make much sense. We should remove them later.
histogram := httpRequestDurationHistogram.
WithLabelValues(handler, code, r.Method)
if traceID := tracing.TraceIDFromContext(r.Context(), true); traceID != "" {
// Need to type-convert the Observer to an
// ExemplarObserver. This will always work for a
// HistogramVec.
histogram.(prometheus.ExemplarObserver).ObserveWithExemplar(
time.Since(now).Seconds(), prometheus.Labels{"traceID": traceID},
)
return
}
histogram.Observe(time.Since(now).Seconds())
switch {
case strings.HasPrefix(req.RequestURI, "/api/datasources/proxy"):
countProxyRequests(status)
case strings.HasPrefix(req.RequestURI, "/api/"):
countApiRequests(status)
default:
countPageRequests(status)
}
switch {
case strings.HasPrefix(r.RequestURI, "/api/datasources/proxy"):
countProxyRequests(status)
case strings.HasPrefix(r.RequestURI, "/api/"):
countApiRequests(status)
default:
countPageRequests(status)
}
})
}
}

View File

@ -61,37 +61,38 @@ func routeOperationName(req *http.Request) (string, bool) {
return "", false
}
func RequestTracing(tracer tracing.Tracer) web.Handler {
return func(res http.ResponseWriter, req *http.Request, c *web.Context) {
if strings.HasPrefix(c.Req.URL.Path, "/public/") ||
c.Req.URL.Path == "/robots.txt" ||
c.Req.URL.Path == "/favicon.ico" {
c.Next()
return
}
func RequestTracing(tracer tracing.Tracer) web.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if strings.HasPrefix(req.URL.Path, "/public/") || req.URL.Path == "/robots.txt" || req.URL.Path == "/favicon.ico" {
next.ServeHTTP(w, req)
return
}
rw := res.(web.ResponseWriter)
rw := web.Rw(w, req)
wireContext := otel.GetTextMapPropagator().Extract(req.Context(), propagation.HeaderCarrier(req.Header))
ctx, span := tracer.Start(req.Context(), fmt.Sprintf("HTTP %s %s", req.Method, req.URL.Path), trace.WithLinks(trace.LinkFromContext(wireContext)))
wireContext := otel.GetTextMapPropagator().Extract(req.Context(), propagation.HeaderCarrier(req.Header))
ctx, span := tracer.Start(req.Context(), fmt.Sprintf("HTTP %s %s", req.Method, req.URL.Path), trace.WithLinks(trace.LinkFromContext(wireContext)))
c.Req = req.WithContext(ctx)
c.Next()
req = req.WithContext(ctx)
next.ServeHTTP(w, req)
// Only call span.Finish when a route operation name have been set,
// meaning that not set the span would not be reported.
if routeOperation, exists := routeOperationName(c.Req); exists {
defer span.End()
span.SetName(fmt.Sprintf("HTTP %s %s", req.Method, routeOperation))
}
// Only call span.Finish when a route operation name have been set,
// meaning that not set the span would not be reported.
// TODO: do not depend on web.Context from the future
if routeOperation, exists := routeOperationName(web.FromContext(req.Context()).Req); exists {
defer span.End()
span.SetName(fmt.Sprintf("HTTP %s %s", req.Method, routeOperation))
}
status := rw.Status()
status := rw.Status()
span.SetAttributes("http.status_code", status, attribute.Int("http.status_code", status))
span.SetAttributes("http.url", req.RequestURI, attribute.String("http.url", req.RequestURI))
span.SetAttributes("http.method", req.Method, attribute.String("http.method", req.Method))
if status >= 400 {
span.SetStatus(codes.Error, fmt.Sprintf("error with HTTP status code %s", strconv.Itoa(status)))
}
span.SetAttributes("http.status_code", status, attribute.Int("http.status_code", status))
span.SetAttributes("http.url", req.RequestURI, attribute.String("http.url", req.RequestURI))
span.SetAttributes("http.method", req.Method, attribute.String("http.method", req.Method))
if status >= 400 {
span.SetStatus(codes.Error, fmt.Sprintf("error with HTTP status code %s", strconv.Itoa(status)))
}
})
}
}

View File

@ -67,6 +67,7 @@ func TestMiddleware(t *testing.T) {
endpointCalled := false
server.Get("/", func(c *models.ReqContext) {
endpointCalled = true
c.Resp.WriteHeader(http.StatusOK)
})
request, err := http.NewRequest(http.MethodGet, "/", nil)

View File

@ -50,10 +50,10 @@ func getRequest(t *testing.T, url string, expStatusCode int) *http.Response {
t.Helper()
// nolint:gosec
resp, err := http.Get(url)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, resp.Body.Close())
})
require.NoError(t, err)
if expStatusCode != resp.StatusCode {
b, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)

View File

@ -29,8 +29,7 @@ import (
// Context represents the runtime context of current request of Macaron instance.
// It is the integration of most frequently used middlewares and helper methods.
type Context struct {
handlers []http.Handler
index int
mws []Middleware
*Router
Req *http.Request
@ -39,30 +38,20 @@ type Context struct {
logger log.Logger
}
func (ctx *Context) handler() http.Handler {
if ctx.index < len(ctx.handlers) {
return ctx.handlers[ctx.index]
}
if ctx.index == len(ctx.handlers) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
}
panic("invalid index for context handler")
}
// Next runs the next handler in the context chain
func (ctx *Context) Next() {
ctx.index++
ctx.run()
}
func (ctx *Context) run() {
for ctx.index <= len(ctx.handlers) {
ctx.handler().ServeHTTP(ctx.Resp, ctx.Req)
h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
for i := len(ctx.mws) - 1; i >= 0; i-- {
h = ctx.mws[i](h)
}
ctx.index++
if ctx.Resp.Written() {
return
}
rw := ctx.Resp
h.ServeHTTP(ctx.Resp, ctx.Req)
// Prevent the handler chain from not writing anything.
// This indicates nearly always that a middleware is misbehaving and not calling its next.ServeHTTP().
// In rare cases where a blank http.StatusOK without any body is wished, explicitly state that using w.WriteStatus(http.StatusOK)
if !rw.Written() {
panic("chain did not write HTTP response")
}
}

View File

@ -53,28 +53,16 @@ type Handler interface{}
//go:linkname hack_wrap github.com/grafana/grafana/pkg/api/response.wrap_handler
func hack_wrap(Handler) http.HandlerFunc
// validateAndWrapHandler makes sure a handler is a callable function, it panics if not.
// When the handler is also potential to be any built-in inject.FastInvoker,
// it wraps the handler automatically to have some performance gain.
func validateAndWrapHandler(h Handler) http.Handler {
// wrapHandler turns any supported handler type into a http.Handler by wrapping it accordingly
func wrapHandler(h Handler) http.Handler {
return hack_wrap(h)
}
// validateAndWrapHandlers preforms validation and wrapping for each input handler.
// It accepts an optional wrapper function to perform custom wrapping on handlers.
func validateAndWrapHandlers(handlers []Handler) []http.Handler {
wrappedHandlers := make([]http.Handler, len(handlers))
for i, h := range handlers {
wrappedHandlers[i] = validateAndWrapHandler(h)
}
return wrappedHandlers
}
// Macaron represents the top level web application.
// Injector methods can be invoked to map services on a global level.
type Macaron struct {
handlers []http.Handler
// handlers []http.Handler
mws []Middleware
urlPrefix string // For suburl support.
*Router
@ -119,43 +107,50 @@ func SetURLParams(r *http.Request, vars map[string]string) *http.Request {
return r.WithContext(context.WithValue(r.Context(), paramsKey{}, vars))
}
// UseMiddleware is a traditional approach to writing middleware in Go.
// A middleware is a function that has a reference to the next handler in the chain
// and returns the actual middleware handler, that may do its job and optionally
// call next.
// Due to how Macaron handles/injects requests and responses we patch the web.Context
// to use the new ResponseWriter and http.Request here. The caller may only call
// `next.ServeHTTP(rw, req)` to pass a modified response writer and/or a request to the
// further middlewares in the chain.
func (m *Macaron) UseMiddleware(middleware func(http.Handler) http.Handler) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
c := FromContext(req.Context())
c.Req = req
if mrw, ok := rw.(*responseWriter); ok {
c.Resp = mrw
} else {
c.Resp = NewResponseWriter(req.Method, rw)
}
c.Next()
})
m.handlers = append(m.handlers, middleware(next))
type Middleware = func(next http.Handler) http.Handler
// UseMiddleware registers the given Middleware
func (m *Macaron) UseMiddleware(mw Middleware) {
m.mws = append(m.mws, mw)
}
// Use adds a middleware Handler to the stack,
// and panics if the handler is not a callable func.
// Middleware Handlers are invoked in the order that they are added.
func (m *Macaron) Use(handler Handler) {
h := validateAndWrapHandler(handler)
m.handlers = append(m.handlers, h)
// Use registers the provided Handler as a middleware.
// The argument may be any supported handler or the Middleware type
// Deprecated: use UseMiddleware instead
func (m *Macaron) Use(h Handler) {
m.mws = append(m.mws, mwFromHandler(h))
}
func mwFromHandler(handler Handler) Middleware {
if mw, ok := handler.(Middleware); ok {
return mw
}
h := wrapHandler(handler)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mrw, ok := w.(*responseWriter)
if !ok {
mrw = NewResponseWriter(r.Method, w).(*responseWriter)
}
h.ServeHTTP(mrw, r)
if mrw.Written() {
return
}
ctx := r.Context().Value(macaronContextKey{}).(*Context)
next.ServeHTTP(ctx.Resp, ctx.Req)
})
}
}
func (m *Macaron) createContext(rw http.ResponseWriter, req *http.Request) *Context {
c := &Context{
handlers: m.handlers,
index: 0,
Router: m.Router,
Resp: NewResponseWriter(req.Method, rw),
logger: log.New("macaron.context"),
mws: m.mws,
Router: m.Router,
Resp: NewResponseWriter(req.Method, rw),
logger: log.New("macaron.context"),
}
c.Req = req.WithContext(context.WithValue(req.Context(), macaronContextKey{}, c))

View File

@ -26,7 +26,7 @@ import (
// Renderer is a Middleware that injects a template renderer into the macaron context, enabling ctx.HTML calls in the handlers.
// If MACARON_ENV is set to "development" then templates will be recompiled on every request. For more performance, set the
// MACARON_ENV environment variable to "production".
func Renderer(dir, leftDelim, rightDelim string) func(http.Handler) http.Handler {
func Renderer(dir, leftDelim, rightDelim string) Middleware {
fs := os.DirFS(dir)
t, err := compileTemplates(fs, leftDelim, rightDelim)
if err != nil {

View File

@ -46,6 +46,16 @@ func NewResponseWriter(method string, rw http.ResponseWriter) ResponseWriter {
return &responseWriter{method, rw, 0, 0, nil}
}
// Rw returns a ResponseWriter. If the argument already satisfies the interface,
// it is returned as is, otherwise it is wrapped using NewResponseWriter
func Rw(rw http.ResponseWriter, req *http.Request) ResponseWriter {
if mrw, ok := rw.(ResponseWriter); ok {
return mrw
}
return NewResponseWriter(req.Method, rw)
}
type responseWriter struct {
method string
http.ResponseWriter

View File

@ -146,13 +146,12 @@ func (r *Router) Handle(method string, pattern string, handlers []Handler) {
h = append(h, handlers...)
handlers = h
}
httpHandlers := validateAndWrapHandlers(handlers)
r.handle(method, pattern, func(resp http.ResponseWriter, req *http.Request, params map[string]string) {
c := r.m.createContext(resp, SetURLParams(req, params))
c.handlers = make([]http.Handler, 0, len(r.m.handlers)+len(handlers))
c.handlers = append(c.handlers, r.m.handlers...)
c.handlers = append(c.handlers, httpHandlers...)
for _, h := range handlers {
c.mws = append(c.mws, mwFromHandler(h))
}
c.run()
})
}
@ -194,12 +193,11 @@ func (r *Router) Any(pattern string, h ...Handler) { r.Handle("*", pattern, h) }
// found. If it is not set, http.NotFound is used.
// Be sure to set 404 response code in your handler.
func (r *Router) NotFound(handlers ...Handler) {
httpHandlers := validateAndWrapHandlers(handlers)
r.notFound = func(rw http.ResponseWriter, req *http.Request) {
c := r.m.createContext(rw, req)
c.handlers = make([]http.Handler, 0, len(r.m.handlers)+len(handlers))
c.handlers = append(c.handlers, r.m.handlers...)
c.handlers = append(c.handlers, httpHandlers...)
for _, h := range handlers {
c.mws = append(c.mws, mwFromHandler(h))
}
c.run()
}
}

View File

@ -0,0 +1,73 @@
package webtest
import (
"net/http"
"net/http/httptest"
"github.com/grafana/grafana/pkg/web"
)
type Context struct {
Req *http.Request
Rw http.ResponseWriter
Next http.Handler
}
// Middleware is a utility for testing middlewares
type Middleware struct {
// Before are run ahead of the returned context
Before []web.Handler
// After are part of the http.Handler chain
After []web.Handler
// The actual handler at the end of the chain
Handler web.Handler
}
// MiddlewareContext returns a *http.Request, http.ResponseWriter and http.Handler
// exactly as if it was passed to a middleware
func MiddlewareContext(test Middleware, req *http.Request) *Context {
m := web.New()
// pkg/web requires the chain to write an HTTP response.
// While this ensures a basic amount of correctness for real handler chains,
// it is naturally incompatible with this package, as we terminate the chain early to pass its
// state to the surrounding test.
// By replacing the http.ResponseWriter and writing to the old one we make pkg/web happy.
m.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
rw := web.Rw(httptest.NewRecorder(), r)
next.ServeHTTP(rw, r)
})
})
for _, mw := range test.Before {
m.Use(mw)
}
ch := make(chan *Context)
m.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ch <- &Context{
Req: r,
Rw: w,
Next: next,
}
})
})
for _, mw := range test.After {
m.Use(mw)
}
// set the provided (or noop) handler to exactly the queried path
handler := test.Handler
if handler == nil {
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
}
m.Handle(req.Method, req.URL.RequestURI(), []web.Handler{handler})
go m.ServeHTTP(httptest.NewRecorder(), req)
return <-ch
}

View File

@ -37,7 +37,7 @@ func NewServer(t testing.TB, routeRegister routing.RouteRegister) *Server {
c.Req = c.Req.WithContext(ctxkey.Set(c.Req.Context(), initCtx))
})
m.Use(requestContextMiddleware())
m.UseMiddleware(requestContextMiddleware())
routeRegister.Register(m.Router)
testServer := httptest.NewServer(m)
@ -126,24 +126,25 @@ func requestContextFromRequest(req *http.Request) *models.ReqContext {
return val
}
func requestContextMiddleware() web.Handler {
return func(res http.ResponseWriter, req *http.Request) {
c := ctxkey.Get(req.Context()).(*models.ReqContext)
func requestContextMiddleware() web.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := ctxkey.Get(r.Context()).(*models.ReqContext)
ctx := requestContextFromRequest(req)
if ctx == nil {
c.Next()
return
}
ctx := requestContextFromRequest(r)
if ctx != nil {
c.SignedInUser = ctx.SignedInUser
c.UserToken = ctx.UserToken
c.IsSignedIn = ctx.IsSignedIn
c.IsRenderCall = ctx.IsRenderCall
c.AllowAnonymous = ctx.AllowAnonymous
c.SkipCache = ctx.SkipCache
c.RequestNonce = ctx.RequestNonce
c.PerfmonTimer = ctx.PerfmonTimer
c.LookupTokenErr = ctx.LookupTokenErr
}
c.SignedInUser = ctx.SignedInUser
c.UserToken = ctx.UserToken
c.IsSignedIn = ctx.IsSignedIn
c.IsRenderCall = ctx.IsRenderCall
c.AllowAnonymous = ctx.AllowAnonymous
c.SkipCache = ctx.SkipCache
c.RequestNonce = ctx.RequestNonce
c.PerfmonTimer = ctx.PerfmonTimer
c.LookupTokenErr = ctx.LookupTokenErr
next.ServeHTTP(w, r)
})
}
}