diff --git a/helper/slowmessage/slowmessage.go b/helper/slowmessage/slowmessage.go new file mode 100644 index 0000000000..e4e1471061 --- /dev/null +++ b/helper/slowmessage/slowmessage.go @@ -0,0 +1,34 @@ +package slowmessage + +import ( + "time" +) + +// SlowFunc is the function that could be slow. Usually, you'll have to +// wrap an existing function in a lambda to make it match this type signature. +type SlowFunc func() error + +// CallbackFunc is the function that is triggered when the threshold is reached. +type CallbackFunc func() + +// Do calls sf. If threshold time has passed, cb is called. Note that this +// call will be made concurrently to sf still running. +func Do(threshold time.Duration, sf SlowFunc, cb CallbackFunc) error { + // Call the slow function + errCh := make(chan error, 1) + go func() { + errCh <- sf() + }() + + // Wait for it to complete or the threshold to pass + select { + case err := <-errCh: + return err + case <-time.After(threshold): + // Threshold reached, call the callback + cb() + } + + // Wait an indefinite amount of time for it to finally complete + return <-errCh +} diff --git a/helper/slowmessage/slowmessage_test.go b/helper/slowmessage/slowmessage_test.go new file mode 100644 index 0000000000..32658aaccc --- /dev/null +++ b/helper/slowmessage/slowmessage_test.go @@ -0,0 +1,82 @@ +package slowmessage + +import ( + "errors" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var sfErr error + cbCalled := false + sfCalled := false + sfSleep := 0 * time.Second + + reset := func() { + cbCalled = false + sfCalled = false + sfErr = nil + } + sf := func() error { + sfCalled = true + time.Sleep(sfSleep) + return sfErr + } + cb := func() { cbCalled = true } + + // SF is not slow + reset() + if err := Do(10*time.Millisecond, sf, cb); err != nil { + t.Fatalf("err: %s", err) + } + + if !sfCalled { + t.Fatal("should call") + } + if cbCalled { + t.Fatal("should not call") + } + + // SF is not slow (with error) + reset() + sfErr = errors.New("error") + if err := Do(10*time.Millisecond, sf, cb); err == nil { + t.Fatalf("err: %s", err) + } + + if !sfCalled { + t.Fatal("should call") + } + if cbCalled { + t.Fatal("should not call") + } + + // SF is slow + reset() + sfSleep = 50 * time.Millisecond + if err := Do(10*time.Millisecond, sf, cb); err != nil { + t.Fatalf("err: %s", err) + } + + if !sfCalled { + t.Fatal("should call") + } + if !cbCalled { + t.Fatal("should call") + } + + // SF is slow (with error) + reset() + sfErr = errors.New("error") + sfSleep = 50 * time.Millisecond + if err := Do(10*time.Millisecond, sf, cb); err == nil { + t.Fatalf("err: %s", err) + } + + if !sfCalled { + t.Fatal("should call") + } + if !cbCalled { + t.Fatal("should call") + } +}