diff --git a/pkg/util/contextutil.go b/pkg/util/contextutil.go new file mode 100644 index 00000000000..16e3a01dd06 --- /dev/null +++ b/pkg/util/contextutil.go @@ -0,0 +1,44 @@ +package util + +import ( + "context" + "sync" + + "github.com/hashicorp/go-multierror" +) + +// this is a decorator for a regular context that contains a custom error and returns the +type contextWithCancellableReason struct { + context.Context + mtx sync.Mutex + err error +} + +func (c *contextWithCancellableReason) Err() error { + c.mtx.Lock() + defer c.mtx.Unlock() + if c.err != nil { + return multierror.Append(c.Context.Err(), c.err) + } + return c.Context.Err() +} + +type CancelCauseFunc func(error) + +// WithCancelCause creates a cancellable context that can be cancelled with a custom reason +func WithCancelCause(parent context.Context) (context.Context, CancelCauseFunc) { + ctx, cancel := context.WithCancel(parent) + errOnce := sync.Once{} + result := &contextWithCancellableReason{ + Context: ctx, + } + cancelFn := func(reason error) { + errOnce.Do(func() { + result.mtx.Lock() + result.err = reason + result.mtx.Unlock() + cancel() + }) + } + return result, cancelFn +} diff --git a/pkg/util/contextutil_test.go b/pkg/util/contextutil_test.go new file mode 100644 index 00000000000..e0d6e9736d3 --- /dev/null +++ b/pkg/util/contextutil_test.go @@ -0,0 +1,46 @@ +package util + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWithCancelWithReason(t *testing.T) { + t.Run("should add custom reason to the standard error", func(t *testing.T) { + expected := errors.New("test-err") + ctx, fn := WithCancelCause(context.Background()) + fn(expected) + select { + case <-ctx.Done(): + default: + require.Fail(t, "the context was not cancelled") + } + require.ErrorIs(t, ctx.Err(), expected) + require.ErrorIs(t, ctx.Err(), context.Canceled) + }) + + t.Run("should return only the first reason if called multiple times", func(t *testing.T) { + expected := errors.New("test-err") + ctx, fn := WithCancelCause(context.Background()) + fn(expected) + fn(errors.New("other error")) + require.ErrorIs(t, ctx.Err(), expected) + }) + + t.Run("should return only the first reason if called multiple times", func(t *testing.T) { + expected := errors.New("test-err") + ctx, fn := WithCancelCause(context.Background()) + fn(expected) + fn(errors.New("other error")) + require.ErrorIs(t, ctx.Err(), expected) + }) + + t.Run("should return context.Canceled if no reason provided", func(t *testing.T) { + ctx, fn := WithCancelCause(context.Background()) + fn(nil) + require.Equal(t, ctx.Err(), context.Canceled) + }) +}