diff --git a/terraform/context.go b/terraform/context.go index ebe9475b91..6fa0ec035b 100644 --- a/terraform/context.go +++ b/terraform/context.go @@ -71,6 +71,21 @@ func (c *Context2) GraphBuilder() GraphBuilder { } } +// Refresh goes through all the resources in the state and refreshes them +// to their latest state. This will update the state that this context +// works with, along with returning it. +// +// Even in the case an error is returned, the state will be returned and +// will potentially be partially updated. +func (c *Context2) Refresh() (*State, []error) { + if _, err := c.walk(walkRefresh); err != nil { + var errs error + return nil, multierror.Append(errs, err).Errors + } + + return nil, nil +} + // Validate validates the configuration and returns any warnings or errors. func (c *Context2) Validate() ([]string, []error) { var errs error @@ -89,17 +104,25 @@ func (c *Context2) Validate() ([]string, []error) { } } - // Build the graph - graph, err := c.GraphBuilder().Build(RootModulePath) + // Walk + walker, err := c.walk(walkValidate) if err != nil { return nil, multierror.Append(errs, err).Errors } - // Walk the graph - walker := &ContextGraphWalker{Context: c, Operation: walkValidate} - graph.Walk(walker) - // Return the result rerrs := multierror.Append(errs, walker.ValidationErrors...) return walker.ValidationWarnings, rerrs.Errors } + +func (c *Context2) walk(operation walkOperation) (*ContextGraphWalker, error) { + // Build the graph + graph, err := c.GraphBuilder().Build(RootModulePath) + if err != nil { + return nil, err + } + + // Walk the graph + walker := &ContextGraphWalker{Context: c, Operation: operation} + return walker, graph.Walk(walker) +} diff --git a/terraform/context_test.go b/terraform/context_test.go index e931430f79..8ce0346702 100644 --- a/terraform/context_test.go +++ b/terraform/context_test.go @@ -2,6 +2,7 @@ package terraform import ( "fmt" + "reflect" "strings" "testing" ) @@ -565,6 +566,58 @@ func TestContext2Validate_varRefFilled(t *testing.T) { } } +func TestContext2Refresh(t *testing.T) { + p := testProvider("aws") + m := testModule(t, "refresh-basic") + ctx := testContext2(t, &ContextOpts{ + Module: m, + Providers: map[string]ResourceProviderFactory{ + "aws": testProviderFuncFixed(p), + }, + State: &State{ + Modules: []*ModuleState{ + &ModuleState{ + Path: rootModulePath, + Resources: map[string]*ResourceState{ + "aws_instance.web": &ResourceState{ + Type: "aws_instance", + Primary: &InstanceState{ + ID: "foo", + }, + }, + }, + }, + }, + }, + }) + + p.RefreshFn = nil + p.RefreshReturn = &InstanceState{ + ID: "foo", + } + + s, err := ctx.Refresh() + mod := s.RootModule() + if err != nil { + t.Fatalf("err: %s", err) + } + if !p.RefreshCalled { + t.Fatal("refresh should be called") + } + if p.RefreshState.ID != "foo" { + t.Fatalf("bad: %#v", p.RefreshState) + } + if !reflect.DeepEqual(mod.Resources["aws_instance.web"].Primary, p.RefreshReturn) { + t.Fatalf("bad: %#v %#v", mod.Resources["aws_instance.web"], p.RefreshReturn) + } + + for _, r := range mod.Resources { + if r.Type == "" { + t.Fatalf("no type: %#v", r) + } + } +} + /* func TestContextInput(t *testing.T) { input := new(MockUIInput) @@ -4158,58 +4211,6 @@ func TestContextPlan_varListErr(t *testing.T) { } } -func TestContextRefresh(t *testing.T) { - p := testProvider("aws") - m := testModule(t, "refresh-basic") - ctx := testContext(t, &ContextOpts{ - Module: m, - Providers: map[string]ResourceProviderFactory{ - "aws": testProviderFuncFixed(p), - }, - State: &State{ - Modules: []*ModuleState{ - &ModuleState{ - Path: rootModulePath, - Resources: map[string]*ResourceState{ - "aws_instance.web": &ResourceState{ - Type: "aws_instance", - Primary: &InstanceState{ - ID: "foo", - }, - }, - }, - }, - }, - }, - }) - - p.RefreshFn = nil - p.RefreshReturn = &InstanceState{ - ID: "foo", - } - - s, err := ctx.Refresh() - mod := s.RootModule() - if err != nil { - t.Fatalf("err: %s", err) - } - if !p.RefreshCalled { - t.Fatal("refresh should be called") - } - if p.RefreshState.ID != "foo" { - t.Fatalf("bad: %#v", p.RefreshState) - } - if !reflect.DeepEqual(mod.Resources["aws_instance.web"].Primary, p.RefreshReturn) { - t.Fatalf("bad: %#v %#v", mod.Resources["aws_instance.web"], p.RefreshReturn) - } - - for _, r := range mod.Resources { - if r.Type == "" { - t.Fatalf("no type: %#v", r) - } - } -} - func TestContextRefresh_delete(t *testing.T) { p := testProvider("aws") m := testModule(t, "refresh-basic") diff --git a/terraform/eval_filter_operation.go b/terraform/eval_filter_operation.go new file mode 100644 index 0000000000..c10e5918a6 --- /dev/null +++ b/terraform/eval_filter_operation.go @@ -0,0 +1,23 @@ +package terraform + +// EvalNodeOpFilterable is an interface that EvalNodes can implement +// to be filterable by the operation that is being run on Terraform. +type EvalNodeOpFilterable interface { + IncludeInOp(walkOperation) bool +} + +// EvalNodeFilterOp returns a filter function that filters nodes that +// include themselves in specific operations. +func EvalNodeFilterOp(op walkOperation) EvalNodeFilterFunc { + return func(n EvalNode) EvalNode { + include := true + if of, ok := n.(EvalNodeOpFilterable); ok { + include = of.IncludeInOp(op) + } + if include { + return n + } + + return EvalNoop{} + } +} diff --git a/terraform/eval_noop.go b/terraform/eval_noop.go new file mode 100644 index 0000000000..cfcdb1fc81 --- /dev/null +++ b/terraform/eval_noop.go @@ -0,0 +1,10 @@ +package terraform + +// EvalNoop is an EvalNode that does nothing. +type EvalNoop struct{} + +func (EvalNoop) Args() ([]EvalNode, []EvalType) { return nil, nil } +func (EvalNoop) Eval(EvalContext, []interface{}) (interface{}, error) { + return nil, nil +} +func (EvalNoop) Type() EvalType { return EvalTypeNull } diff --git a/terraform/graph.go b/terraform/graph.go index 125d5ca571..f6cce7375a 100644 --- a/terraform/graph.go +++ b/terraform/graph.go @@ -117,9 +117,8 @@ func (g *Graph) Dependable(n string) dag.Vertex { // Walk walks the graph with the given walker for callbacks. The graph // will be walked with full parallelism, so the walker should expect // to be called in concurrently. -func (g *Graph) Walk(walker GraphWalker) { - // TODO: test - g.walk(walker) +func (g *Graph) Walk(walker GraphWalker) error { + return g.walk(walker) } func (g *Graph) init() { diff --git a/terraform/graph_walk_context.go b/terraform/graph_walk_context.go index 93bebe5c40..abc8948d49 100644 --- a/terraform/graph_walk_context.go +++ b/terraform/graph_walk_context.go @@ -53,6 +53,12 @@ func (w *ContextGraphWalker) EnterGraph(g *Graph) EvalContext { } } +func (w *ContextGraphWalker) EnterEvalTree(v dag.Vertex, n EvalNode) EvalNode { + // We want to filter the evaluation tree to only include operations + // that belong in this operation. + return EvalFilter(n, EvalNodeFilterOp(w.Operation)) +} + func (w *ContextGraphWalker) ExitEvalTree( v dag.Vertex, output interface{}, err error) { if err == nil {