From 22989acf95fadcfc0b270cd2fcc89499cdc8c8e9 Mon Sep 17 00:00:00 2001 From: Todd Treece <360020+toddtreece@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:10:02 -0400 Subject: [PATCH] K8s: Improve response writer error handling (#88926) --------- Co-authored-by: Diego Augusto Molina --- .../responsewriter/responsewriter.go | 87 +++++++++++++------ .../responsewriter/responsewriter_test.go | 86 ++++++++++++++++-- 2 files changed, 143 insertions(+), 30 deletions(-) diff --git a/pkg/apiserver/endpoints/responsewriter/responsewriter.go b/pkg/apiserver/endpoints/responsewriter/responsewriter.go index 6e635203cfa..150fb7eeb89 100644 --- a/pkg/apiserver/endpoints/responsewriter/responsewriter.go +++ b/pkg/apiserver/endpoints/responsewriter/responsewriter.go @@ -2,9 +2,11 @@ package responsewriter import ( "bufio" + "errors" "fmt" "io" "net/http" + "sync/atomic" "k8s.io/apiserver/pkg/endpoints/responsewriter" "k8s.io/klog/v2" @@ -14,30 +16,38 @@ var _ responsewriter.CloseNotifierFlusher = (*ResponseAdapter)(nil) var _ http.ResponseWriter = (*ResponseAdapter)(nil) var _ io.ReadCloser = (*ResponseAdapter)(nil) +// WrapHandler wraps an http.Handler to return a function that can be used as a [http.RoundTripper]. +// This is used to directly connect the LoopbackConfig [http.RoundTripper] +// with the apiserver's [http.Handler], which avoids the need to start a listener +// for internal clients that use the LoopbackConfig. +// All other requests should not use this wrapper, and should be handled by the +// Grafana HTTP server to ensure that signedInUser middleware is applied. func WrapHandler(handler http.Handler) func(req *http.Request) (*http.Response, error) { // ignore the lint error because the response is passed directly to the client, // so the client will be responsible for closing the response body. //nolint:bodyclose return func(req *http.Request) (*http.Response, error) { w := NewAdapter(req) - resp := w.Response() go func() { handler.ServeHTTP(w, req) if err := w.CloseWriter(); err != nil { klog.Errorf("error closing writer: %v", err) } }() - return resp, nil + + return w.Response() } } // ResponseAdapter is an implementation of [http.ResponseWriter] that allows conversion to a [http.Response]. type ResponseAdapter struct { - req *http.Request - res *http.Response - reader io.ReadCloser - writer io.WriteCloser - buffered *bufio.ReadWriter + req *http.Request + res http.Response + reader io.ReadCloser + writer io.WriteCloser + buffered *bufio.ReadWriter + ready chan struct{} + wroteHeader int32 } // NewAdapter returns an initialized [ResponseAdapter]. @@ -48,7 +58,7 @@ func NewAdapter(req *http.Request) *ResponseAdapter { buffered := bufio.NewReadWriter(reader, writer) return &ResponseAdapter{ req: req, - res: &http.Response{ + res: http.Response{ Proto: req.Proto, ProtoMajor: req.ProtoMajor, ProtoMinor: req.ProtoMinor, @@ -57,6 +67,7 @@ func NewAdapter(req *http.Request) *ResponseAdapter { reader: r, writer: w, buffered: buffered, + ready: make(chan struct{}), } } @@ -68,6 +79,9 @@ func (ra *ResponseAdapter) Header() http.Header { // Write implements [http.ResponseWriter]. func (ra *ResponseAdapter) Write(buf []byte) (int, error) { + // via https://pkg.go.dev/net/http#ResponseWriter.Write + // If WriteHeader is not called explicitly, the first call to Write will trigger an implicit WriteHeader(http.StatusOK). + ra.WriteHeader(http.StatusOK) return ra.buffered.Write(buf) } @@ -78,31 +92,54 @@ func (ra *ResponseAdapter) Read(buf []byte) (int, error) { // WriteHeader implements [http.ResponseWriter]. func (ra *ResponseAdapter) WriteHeader(code int) { - ra.res.StatusCode = code - ra.res.Status = fmt.Sprintf("%03d %s", code, http.StatusText(code)) + if atomic.CompareAndSwapInt32(&ra.wroteHeader, 0, 1) { + ra.res.StatusCode = code + ra.res.Status = fmt.Sprintf("%03d %s", code, http.StatusText(code)) + close(ra.ready) + } } -// Flush implements [http.Flusher]. +// FlushError implements [http.Flusher]. func (ra *ResponseAdapter) Flush() { - if ra.buffered.Writer.Buffered() == 0 { - return - } - - if err := ra.buffered.Writer.Flush(); err != nil { + // We discard io.ErrClosedPipe. This is because as we return the response as + // soon as we have the first write or the status set, the client side with + // the response could potentially call Close on the response body, which + // would cause the reader side of the io.Pipe to be closed. This would cause + // a subsequent call to Write or Flush/FlushError (that have data to write + // to the pipe) to fail with this error. This is expected and legit, and + // this error should be checked by the handler side by either validating the + // error in Write or the one in FlushError. This means it is a + // responsibility of the handler author(s) to handle this error. In other + // cases, we log the error, as it could be potentially not easy to check + // otherwise. + if err := ra.FlushError(); err != nil && !errors.Is(err, io.ErrClosedPipe) { klog.Error("Error flushing response buffer: ", "error", err) } } -// Response returns the [http.Response] generated by the [http.Handler]. -func (ra *ResponseAdapter) Response() *http.Response { - // make sure to set the status code to 200 if the request is a watch - // this is to ensure that client-go uses a streamwatcher: - // https://github.com/kubernetes/client-go/blob/76174b8af8cfd938018b04198595d65b48a69334/rest/request.go#L737 - if ra.res.StatusCode == 0 && ra.req.URL.Query().Get("watch") == "true" { - ra.WriteHeader(http.StatusOK) +// FlushError implements an alternative Flush that returns an error. This is +// internally used in net/http and in some standard library utilities. +func (ra *ResponseAdapter) FlushError() error { + if ra.buffered.Writer.Buffered() == 0 { + return nil + } + + return ra.buffered.Writer.Flush() +} + +// Response returns the [http.Response] generated by the [http.Handler]. +func (ra *ResponseAdapter) Response() (*http.Response, error) { + ctx := ra.req.Context() + select { + case <-ctx.Done(): + return nil, ctx.Err() + + case <-ra.ready: + res := ra.res + res.Body = ra + + return &res, nil } - ra.res.Body = ra - return ra.res } // Decorate implements [responsewriter.UserProvidedDecorator]. diff --git a/pkg/apiserver/endpoints/responsewriter/responsewriter_test.go b/pkg/apiserver/endpoints/responsewriter/responsewriter_test.go index 4c759ac6485..7cfbd015fcf 100644 --- a/pkg/apiserver/endpoints/responsewriter/responsewriter_test.go +++ b/pkg/apiserver/endpoints/responsewriter/responsewriter_test.go @@ -1,9 +1,11 @@ package responsewriter_test import ( + "context" "io" "math/rand" "net/http" + "sync" "testing" "time" @@ -23,7 +25,6 @@ func TestResponseAdapter(t *testing.T) { fn: grafanaresponsewriter.WrapHandler(http.HandlerFunc(syncHandler)), }, } - close(client.Transport.(*roundTripperFunc).ready) req, err := http.NewRequest("GET", "http://localhost/test", nil) require.NoError(t, err) @@ -40,7 +41,7 @@ func TestResponseAdapter(t *testing.T) { require.Equal(t, "OK", string(bodyBytes)) }) - t.Run("should handle synchronous write", func(t *testing.T) { + t.Run("should handle asynchronous write", func(t *testing.T) { generateRandomStrings(10) client := &http.Client{ Transport: &roundTripperFunc{ @@ -51,7 +52,6 @@ func TestResponseAdapter(t *testing.T) { fn: grafanaresponsewriter.WrapHandler(http.HandlerFunc(asyncHandler)), }, } - close(client.Transport.(*roundTripperFunc).ready) req, err := http.NewRequest("GET", "http://localhost/test?watch=true", nil) require.NoError(t, err) @@ -87,15 +87,84 @@ func TestResponseAdapter(t *testing.T) { } } }) + + t.Run("should handle asynchronous err", func(t *testing.T) { + client := &http.Client{ + Transport: &roundTripperFunc{ + ready: make(chan struct{}), + // ignore the lint error because the response is passed directly to the client, + // so the client will be responsible for closing the response body. + //nolint:bodyclose + fn: grafanaresponsewriter.WrapHandler(http.HandlerFunc(asyncErrHandler)), + }, + } + req, err := http.NewRequest("GET", "http://localhost/test?watch=true", nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + + defer func() { + err := resp.Body.Close() + require.NoError(t, err) + }() + + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + }) + + t.Run("should handle context cancellation", func(t *testing.T) { + var cancel context.CancelFunc + client := &http.Client{ + Transport: &roundTripperFunc{ + ready: make(chan struct{}), + // ignore the lint error because the response is passed directly to the client, + // so the client will be responsible for closing the response body. + //nolint:bodyclose + fn: grafanaresponsewriter.WrapHandler(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + cancel() + })), + }, + } + req, err := http.NewRequest("GET", "http://localhost/test?watch=true", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(req.Context()) //nolint:govet + req = req.WithContext(ctx) + resp, err := client.Do(req) //nolint:bodyclose + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }) //nolint:govet + + t.Run("should gracefully handle concurrent WriteHeader calls", func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + const maxAttempts = 1000 + var wg sync.WaitGroup + for i := 0; i < maxAttempts; i++ { + ra := grafanaresponsewriter.NewAdapter(req) + wg.Add(2) + go func() { + defer wg.Done() + ra.WriteHeader(http.StatusOK) + }() + go func() { + defer wg.Done() + ra.WriteHeader(http.StatusOK) + }() + } + wg.Wait() + }) } func syncHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("OK")) } func asyncHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) for _, s := range randomStrings { time.Sleep(100 * time.Millisecond) // write the current iteration @@ -104,6 +173,13 @@ func asyncHandler(w http.ResponseWriter, r *http.Request) { } } +func asyncErrHandler(w http.ResponseWriter, _ *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("error")) + w.(http.Flusher).Flush() +} + var randomStrings = []string{} func generateRandomStrings(n int) {