mirror of
https://github.com/grafana/grafana.git
synced 2024-11-29 12:14:08 -06:00
234 lines
8.0 KiB
Go
234 lines
8.0 KiB
Go
package proxyutil
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/grafana/grafana/pkg/infra/log"
|
|
"github.com/grafana/grafana/pkg/middleware/requestmeta"
|
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
)
|
|
|
|
func TestReverseProxy(t *testing.T) {
|
|
t.Run("When proxying a request should enforce request and response constraints", func(t *testing.T) {
|
|
var actualReq *http.Request
|
|
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
actualReq = req
|
|
http.SetCookie(w, &http.Cookie{Name: "test"})
|
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000")
|
|
w.Header().Set("X-Custom-Hdr", "Ok!")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
|
|
req.Header.Set("X-Forwarded-Host", "forwarded.host.com")
|
|
req.Header.Set("X-Forwarded-Port", "8080")
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("Origin", "test.com")
|
|
req.Header.Set("Referer", "https://test.com/api")
|
|
req.RemoteAddr = "10.0.0.1"
|
|
|
|
req = req.WithContext(contexthandler.WithAuthHTTPHeaders(req.Context(), setting.NewCfg()))
|
|
req.Header.Set("Authorization", "val")
|
|
|
|
rp := NewReverseProxy(log.New("test"), func(req *http.Request) {
|
|
req.Header.Set("X-KEY", "value")
|
|
})
|
|
require.NotNil(t, rp)
|
|
require.NotNil(t, rp.ModifyResponse)
|
|
rp.ServeHTTP(rec, req)
|
|
|
|
require.NotNil(t, actualReq)
|
|
require.Empty(t, actualReq.Header.Get("X-Forwarded-Host"))
|
|
require.Empty(t, actualReq.Header.Get("X-Forwarded-Port"))
|
|
require.Empty(t, actualReq.Header.Get("X-Forwarded-Proto"))
|
|
require.Equal(t, "10.0.0.1", actualReq.Header.Get("X-Forwarded-For"))
|
|
require.Empty(t, actualReq.Header.Get("Origin"))
|
|
require.Empty(t, actualReq.Header.Get("Referer"))
|
|
require.Equal(t, "https://test.com/api", actualReq.Header.Get("X-Grafana-Referer"))
|
|
require.Equal(t, "value", actualReq.Header.Get("X-KEY"))
|
|
require.Empty(t, actualReq.Header.Get("Authorization"))
|
|
resp := rec.Result()
|
|
require.Empty(t, resp.Cookies())
|
|
require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy"))
|
|
require.Contains(t, resp.Header, "X-Custom-Hdr")
|
|
require.NotContains(t, resp.Header, "Strict-Transport-Security")
|
|
require.Contains(t, resp.Header.Get("Via"), "grafana")
|
|
require.NoError(t, resp.Body.Close())
|
|
})
|
|
|
|
t.Run("When proxying a request using WithModifyResponse should call it before default ModifyResponse func", func(t *testing.T) {
|
|
var actualReq *http.Request
|
|
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
actualReq = req
|
|
http.SetCookie(w, &http.Cookie{Name: "test"})
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
|
|
rp := NewReverseProxy(
|
|
log.New("test"),
|
|
func(req *http.Request) {
|
|
req.Header.Set("X-KEY", "value")
|
|
},
|
|
WithModifyResponse(func(r *http.Response) error {
|
|
r.Header.Set("X-KEY2", "value2")
|
|
return nil
|
|
}),
|
|
)
|
|
require.NotNil(t, rp)
|
|
require.NotNil(t, rp.ModifyResponse)
|
|
rp.ServeHTTP(rec, req)
|
|
|
|
require.NotNil(t, actualReq)
|
|
require.Equal(t, "value", actualReq.Header.Get("X-KEY"))
|
|
resp := rec.Result()
|
|
require.Empty(t, resp.Cookies())
|
|
require.Equal(t, "sandbox", resp.Header.Get("Content-Security-Policy"))
|
|
require.Equal(t, "value2", resp.Header.Get("X-KEY2"))
|
|
require.NoError(t, resp.Body.Close())
|
|
})
|
|
|
|
t.Run("Error handling should convert status codes depending on what kind of error it is and set downstream status source", func(t *testing.T) {
|
|
timedOutTransport := http.DefaultTransport.(*http.Transport)
|
|
timedOutTransport.ResponseHeaderTimeout = time.Millisecond
|
|
|
|
testCases := []struct {
|
|
desc string
|
|
transport http.RoundTripper
|
|
responseWaitTime time.Duration
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
desc: "Cancelled request should return 499 Client closed request",
|
|
transport: &cancelledRoundTripper{},
|
|
expectedStatusCode: StatusClientClosedRequest,
|
|
},
|
|
{
|
|
desc: "Timed out request should return 504 Gateway timeout",
|
|
transport: timedOutTransport,
|
|
responseWaitTime: 100 * time.Millisecond,
|
|
expectedStatusCode: http.StatusGatewayTimeout,
|
|
},
|
|
{
|
|
desc: "Failed request should return 502 Bad gateway",
|
|
transport: &failingRoundTripper{},
|
|
expectedStatusCode: http.StatusBadGateway,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
if tc.responseWaitTime > 0 {
|
|
time.Sleep(tc.responseWaitTime)
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
rec := httptest.NewRecorder()
|
|
|
|
ctx := requestmeta.SetRequestMetaData(context.Background(), requestmeta.RequestMetaData{
|
|
StatusSource: requestmeta.StatusSourceServer,
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
|
|
req = req.WithContext(ctx)
|
|
|
|
rp := NewReverseProxy(
|
|
log.New("test"),
|
|
func(req *http.Request) {},
|
|
WithTransport(tc.transport),
|
|
)
|
|
require.NotNil(t, rp)
|
|
require.NotNil(t, rp.Transport)
|
|
require.Same(t, tc.transport, rp.Transport)
|
|
rp.ServeHTTP(rec, req)
|
|
|
|
resp := rec.Result()
|
|
require.Equal(t, tc.expectedStatusCode, resp.StatusCode)
|
|
require.NoError(t, resp.Body.Close())
|
|
|
|
rmd := requestmeta.GetRequestMetaData(ctx)
|
|
require.Equal(t, requestmeta.StatusSourceDownstream, rmd.StatusSource)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("5xx response status codes should set downstream status source", func(t *testing.T) {
|
|
testCases := []struct {
|
|
status int
|
|
expectedSource requestmeta.StatusSource
|
|
}{
|
|
{status: http.StatusOK, expectedSource: requestmeta.StatusSourceServer},
|
|
{status: http.StatusBadRequest, expectedSource: requestmeta.StatusSourceServer},
|
|
{status: http.StatusForbidden, expectedSource: requestmeta.StatusSourceServer},
|
|
{status: http.StatusUnauthorized, expectedSource: requestmeta.StatusSourceServer},
|
|
{status: http.StatusInternalServerError, expectedSource: requestmeta.StatusSourceDownstream},
|
|
{status: http.StatusBadGateway, expectedSource: requestmeta.StatusSourceDownstream},
|
|
{status: http.StatusGatewayTimeout, expectedSource: requestmeta.StatusSourceDownstream},
|
|
{status: 599, expectedSource: requestmeta.StatusSourceDownstream},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
tc := testCase
|
|
t.Run(fmt.Sprintf("status %d => source %s ", tc.status, tc.expectedSource), func(t *testing.T) {
|
|
upstream := newUpstreamServer(t, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
w.WriteHeader(tc.status)
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
rec := httptest.NewRecorder()
|
|
|
|
ctx := requestmeta.SetRequestMetaData(context.Background(), requestmeta.RequestMetaData{
|
|
StatusSource: requestmeta.StatusSourceServer,
|
|
})
|
|
req := httptest.NewRequest(http.MethodGet, upstream.URL, nil)
|
|
req = req.WithContext(ctx)
|
|
|
|
rp := NewReverseProxy(
|
|
log.New("test"),
|
|
func(req *http.Request) {},
|
|
)
|
|
require.NotNil(t, rp)
|
|
rp.ServeHTTP(rec, req)
|
|
|
|
resp := rec.Result()
|
|
require.Equal(t, tc.status, resp.StatusCode)
|
|
require.NoError(t, resp.Body.Close())
|
|
|
|
rmd := requestmeta.GetRequestMetaData(ctx)
|
|
require.Equal(t, tc.expectedSource, rmd.StatusSource)
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func newUpstreamServer(t *testing.T, handler http.Handler) *httptest.Server {
|
|
t.Helper()
|
|
|
|
upstream := httptest.NewServer(handler)
|
|
return upstream
|
|
}
|
|
|
|
type cancelledRoundTripper struct{}
|
|
|
|
func (cancelledRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
|
|
return nil, context.Canceled
|
|
}
|
|
|
|
type failingRoundTripper struct{}
|
|
|
|
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
|
|
return nil, errors.New("some error")
|
|
}
|