From 54fd742ef6314a0ff465a0c7fdf2697097d315c7 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sun, 8 Feb 2015 17:06:17 -0800 Subject: [PATCH] dag: walk should be able to be halted --- dag/dag.go | 68 ++++++++++++++++++++++++++++++++++++------------- dag/dag_test.go | 41 ++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 18 deletions(-) diff --git a/dag/dag.go b/dag/dag.go index 33f8571a84..9490cf98a3 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -15,7 +15,7 @@ type AcyclicGraph struct { } // WalkFunc is the callback used for walking the graph. -type WalkFunc func(Vertex) +type WalkFunc func(Vertex) error // Root returns the root of the DAG, or an error. // @@ -79,7 +79,8 @@ func (g *AcyclicGraph) Validate() error { } // Walk walks the graph, calling your callback as each node is visited. -// This will walk nodes in parallel if it can. +// This will walk nodes in parallel if it can. Because the walk is done +// in parallel, the error returned will be a multierror. func (g *AcyclicGraph) Walk(cb WalkFunc) error { // Cache the vertices since we use it multiple times vertices := g.Vertices() @@ -98,32 +99,65 @@ func (g *AcyclicGraph) Walk(cb WalkFunc) error { for _, v := range vertices { vertMap[v] = make(chan struct{}) } + + // The map of whether a vertex errored or not during the walk + var errLock sync.Mutex + var errs error + errMap := make(map[Vertex]bool) for _, v := range vertices { - // Get the list of channels to wait on - deps := g.DownEdges(v).List() + // Build our list of dependencies and the list of channels to + // wait on until we start executing for this vertex. + depsRaw := g.DownEdges(v).List() + deps := make([]Vertex, len(depsRaw)) depChs := make([]<-chan struct{}, len(deps)) - for i, dep := range deps { - depChs[i] = vertMap[dep.(Vertex)] + for i, raw := range depsRaw { + deps[i] = raw.(Vertex) + depChs[i] = vertMap[deps[i]] } - // Get our channel + // Get our channel so that we can close it when we're done ourCh := vertMap[v] - // Start the goroutine - go func(v Vertex, doneCh chan<- struct{}, chs []<-chan struct{}) { - defer close(doneCh) - defer wg.Done() - - // Wait on all our dependencies + // Start the goroutine to wait for our dependencies + readyCh := make(chan bool) + go func(deps []Vertex, chs []<-chan struct{}, readyCh chan<- bool) { + // First wait for all the dependencies for _, ch := range chs { <-ch } - // Call our callback - cb(v) - }(v, ourCh, depChs) + // Then, check the map to see if any of our dependencies failed + errLock.Lock() + defer errLock.Unlock() + for _, dep := range deps { + if errMap[dep] { + readyCh <- false + return + } + } + + readyCh <- true + }(deps, depChs, readyCh) + + // Start the goroutine that executes + go func(v Vertex, doneCh chan<- struct{}, readyCh <-chan bool) { + defer close(doneCh) + defer wg.Done() + + var err error + if ready := <-readyCh; ready { + err = cb(v) + } + + errLock.Lock() + defer errLock.Unlock() + if err != nil { + errMap[v] = true + errs = multierror.Append(errs, err) + } + }(v, ourCh, readyCh) } <-doneCh - return nil + return errs } diff --git a/dag/dag_test.go b/dag/dag_test.go index 77f548f4fe..cc9442be57 100644 --- a/dag/dag_test.go +++ b/dag/dag_test.go @@ -1,6 +1,7 @@ package dag import ( + "fmt" "reflect" "sync" "testing" @@ -96,10 +97,11 @@ func TestAcyclicGraphWalk(t *testing.T) { var visits []Vertex var lock sync.Mutex - err := g.Walk(func(v Vertex) { + err := g.Walk(func(v Vertex) error { lock.Lock() defer lock.Unlock() visits = append(visits, v) + return nil }) if err != nil { t.Fatalf("err: %s", err) @@ -117,3 +119,40 @@ func TestAcyclicGraphWalk(t *testing.T) { t.Fatalf("bad: %#v", visits) } + +func TestAcyclicGraphWalk_error(t *testing.T) { + var g AcyclicGraph + g.Add(1) + g.Add(2) + g.Add(3) + g.Connect(BasicEdge(3, 2)) + g.Connect(BasicEdge(3, 1)) + + var visits []Vertex + var lock sync.Mutex + err := g.Walk(func(v Vertex) error { + lock.Lock() + defer lock.Unlock() + + if v == 2 { + return fmt.Errorf("error") + } + + visits = append(visits, v) + return nil + }) + if err == nil { + t.Fatal("should error") + } + + expected := [][]Vertex{ + {1}, + } + for _, e := range expected { + if reflect.DeepEqual(visits, e) { + return + } + } + + t.Fatalf("bad: %#v", visits) +}