diff --git a/command/apply_test.go b/command/apply_test.go index 1b6e87069f..23e1f3f6ed 100644 --- a/command/apply_test.go +++ b/command/apply_test.go @@ -12,7 +12,6 @@ import ( "reflect" "strings" "sync" - "sync/atomic" "testing" "time" @@ -59,24 +58,57 @@ func TestApply(t *testing.T) { } } +// high water mark counter +type hwm struct { + sync.Mutex + val int + max int +} + +func (t *hwm) Inc() { + t.Lock() + defer t.Unlock() + t.val++ + if t.val > t.max { + t.max = t.val + } +} + +func (t *hwm) Dec() { + t.Lock() + defer t.Unlock() + t.val-- +} + +func (t *hwm) Max() int { + t.Lock() + defer t.Unlock() + return t.max +} + func TestApply_parallelism(t *testing.T) { provider := testProvider() statePath := testTempFile(t) + par := 4 + // This blocks all the appy functions. We close it when we exit so // they end quickly after this test finishes. block := make(chan struct{}) - defer close(block) + // signal how many goroutines have started + started := make(chan int, 100) + + runCount := &hwm{} - var runCount uint64 provider.ApplyFn = func( i *terraform.InstanceInfo, s *terraform.InstanceState, d *terraform.InstanceDiff) (*terraform.InstanceState, error) { // Increment so we're counting parallelism - atomic.AddUint64(&runCount, 1) - - // Block until we're done + started <- 1 + runCount.Inc() + defer runCount.Dec() + // Block here to stage up our max number of parallel instances <-block return nil, nil @@ -90,30 +122,46 @@ func TestApply_parallelism(t *testing.T) { }, } - par := uint64(5) args := []string{ "-state", statePath, fmt.Sprintf("-parallelism=%d", par), testFixturePath("parallelism"), } - // Run in a goroutine. We still try to catch any errors and - // get them on the error channel. - errCh := make(chan string, 1) + // Run in a goroutine. We can get any errors from the ui.OutputWriter + doneCh := make(chan int, 1) go func() { - if code := c.Run(args); code != 0 { - errCh <- ui.OutputWriter.String() - } + doneCh <- c.Run(args) }() + + timeout := time.After(5 * time.Second) + + // ensure things are running + for i := 0; i < par; i++ { + select { + case <-timeout: + t.Fatal("timeout waiting for all goroutines to start") + case <-started: + } + } + + // a little extra sleep, since we can't ensure all goroutines from the walk have + // really started + time.Sleep(100 * time.Millisecond) + close(block) + select { - case <-time.After(1000 * time.Millisecond): - case err := <-errCh: - t.Fatalf("err: %s", err) + case res := <-doneCh: + if res != 0 { + t.Fatal(ui.OutputWriter.String()) + } + case <-timeout: + t.Fatal("timeout waiting from Run()") } // The total in flight should equal the parallelism - if rc := atomic.LoadUint64(&runCount); rc != par { - t.Fatalf("Expected parallelism: %d, got: %d", par, rc) + if runCount.Max() != par { + t.Fatalf("Expected parallelism: %d, got: %d", par, runCount.Max()) } }