Chore: Use response limit middleware from SDK (#83915)

This commit is contained in:
Andres Martinez Gotor
2024-03-13 10:14:16 +01:00
committed by GitHub
parent ecd6de826a
commit c061cc33cc
12 changed files with 28 additions and 201 deletions

View File

@@ -27,7 +27,7 @@ func New(cfg *setting.Cfg, validator validations.PluginRequestValidator, tracer
SetUserAgentMiddleware(cfg.DataProxyUserAgent),
sdkhttpclient.BasicAuthenticationMiddleware(),
sdkhttpclient.CustomHeadersMiddleware(),
ResponseLimitMiddleware(cfg.ResponseLimit),
sdkhttpclient.ResponseLimitMiddleware(cfg.ResponseLimit),
RedirectLimitMiddleware(validator),
}

View File

@@ -33,7 +33,7 @@ func TestHTTPClientProvider(t *testing.T) {
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
})
t.Run("When creating new provider and SigV4 is enabled should apply expected middleware", func(t *testing.T) {
@@ -57,7 +57,7 @@ func TestHTTPClientProvider(t *testing.T) {
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, SigV4MiddlewareName, o.Middlewares[8].(sdkhttpclient.MiddlewareName).MiddlewareName())
})
@@ -82,7 +82,7 @@ func TestHTTPClientProvider(t *testing.T) {
require.Equal(t, SetUserAgentMiddlewareName, o.Middlewares[3].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.BasicAuthenticationMiddlewareName, o.Middlewares[4].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.CustomHeadersMiddlewareName, o.Middlewares[5].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, sdkhttpclient.ResponseLimitMiddlewareName, o.Middlewares[6].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, HostRedirectValidationMiddlewareName, o.Middlewares[7].(sdkhttpclient.MiddlewareName).MiddlewareName())
require.Equal(t, HTTPLoggerMiddlewareName, o.Middlewares[8].(sdkhttpclient.MiddlewareName).MiddlewareName())
})

View File

@@ -1,31 +0,0 @@
package httpclientprovider
import (
"net/http"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient"
)
// ResponseLimitMiddlewareName is the middleware name used by ResponseLimitMiddleware.
const ResponseLimitMiddlewareName = "response-limit"
func ResponseLimitMiddleware(limit int64) sdkhttpclient.Middleware {
return sdkhttpclient.NamedMiddlewareFunc(ResponseLimitMiddlewareName, func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper {
if limit <= 0 {
return next
}
return sdkhttpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
res, err := next.RoundTrip(req)
if err != nil {
return nil, err
}
if res != nil && res.StatusCode != http.StatusSwitchingProtocols {
res.Body = httpclient.MaxBytesReader(res.Body, limit)
}
return res, nil
})
})
}

View File

@@ -1,60 +0,0 @@
package httpclientprovider
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/stretchr/testify/require"
)
func TestResponseLimitMiddleware(t *testing.T) {
tcs := []struct {
limit int64
bodyLength int
body string
err error
}{
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")},
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil},
{limit: 0, bodyLength: 5, body: "dummy", err: nil},
}
for _, tc := range tcs {
t.Run(fmt.Sprintf("Test ResponseLimitMiddleware with limit: %d", tc.limit), func(t *testing.T) {
finalRoundTripper := httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK, Request: req, Body: io.NopCloser(strings.NewReader("dummy"))}, nil
})
mw := ResponseLimitMiddleware(tc.limit)
rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper)
require.NotNil(t, rt)
middlewareName, ok := mw.(httpclient.MiddlewareName)
require.True(t, ok)
require.Equal(t, ResponseLimitMiddlewareName, middlewareName.MiddlewareName())
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://test.com/query", nil)
require.NoError(t, err)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
require.NotNil(t, res)
require.NotNil(t, res.Body)
require.NoError(t, res.Body.Close())
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
require.EqualError(t, tc.err, err.Error())
} else {
require.NoError(t, tc.err)
}
require.Len(t, bodyBytes, tc.bodyLength)
require.Equal(t, string(bodyBytes), tc.body)
})
}
}

View File

@@ -1,66 +0,0 @@
package httpclient
import (
"errors"
"fmt"
"io"
)
// Similar implementation to http/net MaxBytesReader
// https://pkg.go.dev/net/http#MaxBytesReader
// What's happening differently here, is that the field that
// is limited is the response and not the request, thus
// the error handling/message needed to be accurate.
// ErrResponseBodyTooLarge indicates response body is too large
var ErrResponseBodyTooLarge = errors.New("http: response body too large")
// MaxBytesReader is similar to io.LimitReader but is intended for
// limiting the size of incoming request bodies. In contrast to
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
// non-EOF error for a Read beyond the limit, and closes the
// underlying reader when its Close method is called.
//
// MaxBytesReader prevents clients from accidentally or maliciously
// sending a large request and wasting server resources.
func MaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{r: r, n: n}
}
type maxBytesReader struct {
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
n = int(l.n)
l.n = 0
l.err = fmt.Errorf("error: %w, response limit is set to: %d", ErrResponseBodyTooLarge, n)
return n, l.err
}
func (l *maxBytesReader) Close() error {
return l.r.Close()
}

View File

@@ -1,40 +0,0 @@
package httpclient
import (
"errors"
"fmt"
"io"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestMaxBytesReader(t *testing.T) {
tcs := []struct {
limit int64
bodyLength int
body string
err error
}{
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")},
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil},
{limit: 0, bodyLength: 0, body: "", err: errors.New("error: http: response body too large, response limit is set to: 0")},
}
for _, tc := range tcs {
t.Run(fmt.Sprintf("Test MaxBytesReader with limit: %d", tc.limit), func(t *testing.T) {
body := io.NopCloser(strings.NewReader("dummy"))
readCloser := MaxBytesReader(body, tc.limit)
bodyBytes, err := io.ReadAll(readCloser)
if err != nil {
require.EqualError(t, tc.err, err.Error())
} else {
require.NoError(t, tc.err)
}
require.Len(t, bodyBytes, tc.bodyLength)
require.Equal(t, string(bodyBytes), tc.body)
})
}
}

View File

@@ -0,0 +1,15 @@
package datasource
import (
"context"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
)
func contextualMiddlewares(ctx context.Context) context.Context {
cfg := backend.GrafanaConfigFromContext(ctx)
m := httpclient.ResponseLimitMiddleware(cfg.ResponseLimit())
return httpclient.WithContextualMiddleware(ctx, m)
}

View File

@@ -50,6 +50,7 @@ func (r *subHealthREST) Connect(ctx context.Context, name string, opts runtime.O
return nil, err
}
ctx = backend.WithGrafanaConfig(ctx, pluginCtx.GrafanaConfig)
ctx = contextualMiddlewares(ctx)
healthResponse, err := r.builder.client.CheckHealth(ctx, &backend.CheckHealthRequest{
PluginContext: pluginCtx,

View File

@@ -72,6 +72,7 @@ func (r *subQueryREST) Connect(ctx context.Context, name string, opts runtime.Ob
}
ctx = backend.WithGrafanaConfig(ctx, pluginCtx.GrafanaConfig)
ctx = contextualMiddlewares(ctx)
rsp, err := r.builder.client.QueryData(ctx, &backend.QueryDataRequest{
Queries: queries,
PluginContext: pluginCtx,

View File

@@ -51,6 +51,7 @@ func (r *subResourceREST) Connect(ctx context.Context, name string, opts runtime
return nil, err
}
ctx = backend.WithGrafanaConfig(ctx, pluginCtx.GrafanaConfig)
ctx = contextualMiddlewares(ctx)
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body, err := io.ReadAll(req.Body)

View File

@@ -63,6 +63,7 @@ type PluginInstanceCfg struct {
GrafanaVersion string
ConcurrentQueryCount int
ResponseLimit int64
UserFacingDefaultError string
@@ -115,6 +116,7 @@ func ProvidePluginInstanceConfig(cfg *setting.Cfg, settingProvider setting.Provi
SQLDatasourceMaxOpenConnsDefault: cfg.SqlDatasourceMaxOpenConnsDefault,
SQLDatasourceMaxIdleConnsDefault: cfg.SqlDatasourceMaxIdleConnsDefault,
SQLDatasourceMaxConnLifetimeDefault: cfg.SqlDatasourceMaxConnLifetimeDefault,
ResponseLimit: cfg.ResponseLimit,
}, nil
}

View File

@@ -147,5 +147,9 @@ func (s *RequestConfigProvider) PluginRequestConfig(ctx context.Context, pluginI
m[backend.SQLMaxIdleConnsDefault] = strconv.Itoa(s.cfg.SQLDatasourceMaxIdleConnsDefault)
m[backend.SQLMaxConnLifetimeSecondsDefault] = strconv.Itoa(s.cfg.SQLDatasourceMaxConnLifetimeDefault)
if s.cfg.ResponseLimit > 0 {
m[backend.ResponseLimit] = strconv.FormatInt(s.cfg.ResponseLimit, 10)
}
return m
}