From 8ce7ef6188dc174cbab7befeb6bde3be69719926 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 13 Jan 2015 11:24:42 -0800 Subject: [PATCH] config/lang: implement identifier semantic check --- config/lang/check_identifier.go | 79 ++++++++++++++++++++++++ config/lang/check_identifier_test.go | 89 ++++++++++++++++++++++++++++ config/lang/engine.go | 18 +++++- 3 files changed, 183 insertions(+), 3 deletions(-) create mode 100644 config/lang/check_identifier.go create mode 100644 config/lang/check_identifier_test.go diff --git a/config/lang/check_identifier.go b/config/lang/check_identifier.go new file mode 100644 index 0000000000..2cc49bad10 --- /dev/null +++ b/config/lang/check_identifier.go @@ -0,0 +1,79 @@ +package lang + +import ( + "fmt" + "sync" + + "github.com/hashicorp/terraform/config/lang/ast" +) + +// IdentifierCheck is a SemanticCheck that checks that all identifiers +// resolve properly and that the right number of arguments are passed +// to functions. +type IdentifierCheck struct { + Scope *Scope + + err error + lock sync.Mutex +} + +func (c *IdentifierCheck) Visit(root ast.Node) error { + c.lock.Lock() + defer c.lock.Unlock() + defer c.reset() + root.Accept(c.visit) + return c.err +} + +func (c *IdentifierCheck) visit(raw ast.Node) { + if c.err != nil { + return + } + + switch n := raw.(type) { + case *ast.Call: + c.visitCall(n) + case *ast.VariableAccess: + c.visitVariableAccess(n) + case *ast.Concat: + // Ignore + case *ast.LiteralNode: + // Ignore + default: + c.createErr(n, fmt.Sprintf("unknown node: %#v", raw)) + } +} + +func (c *IdentifierCheck) visitCall(n *ast.Call) { + // Look up the function in the map + function, ok := c.Scope.LookupFunc(n.Func) + if !ok { + c.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func)) + return + } + + // Verify the number of arguments + if len(n.Args) != len(function.ArgTypes) { + c.createErr(n, fmt.Sprintf( + "%s: expected %d arguments, got %d", + n.Func, len(function.ArgTypes), len(n.Args))) + return + } +} + +func (c *IdentifierCheck) visitVariableAccess(n *ast.VariableAccess) { + // Look up the variable in the map + if _, ok := c.Scope.LookupVar(n.Name); !ok { + c.createErr(n, fmt.Sprintf( + "unknown variable accessed: %s", n.Name)) + return + } +} + +func (c *IdentifierCheck) createErr(n ast.Node, str string) { + c.err = fmt.Errorf("%s: %s", n.Pos(), str) +} + +func (c *IdentifierCheck) reset() { + c.err = nil +} diff --git a/config/lang/check_identifier_test.go b/config/lang/check_identifier_test.go new file mode 100644 index 0000000000..52460aa52b --- /dev/null +++ b/config/lang/check_identifier_test.go @@ -0,0 +1,89 @@ +package lang + +import ( + "testing" + + "github.com/hashicorp/terraform/config/lang/ast" +) + +func TestIdentifierCheck(t *testing.T) { + cases := []struct { + Input string + Scope *Scope + Error bool + }{ + { + "foo", + &Scope{}, + false, + }, + + { + "foo ${bar} success", + &Scope{ + VarMap: map[string]Variable{ + "bar": Variable{ + Value: "baz", + Type: ast.TypeString, + }, + }, + }, + false, + }, + + { + "foo ${bar}", + &Scope{}, + true, + }, + + { + "foo ${rand()} success", + &Scope{ + FuncMap: map[string]Function{ + "rand": Function{ + ReturnType: ast.TypeString, + Callback: func([]interface{}) (interface{}, error) { + return "42", nil + }, + }, + }, + }, + false, + }, + + { + "foo ${rand()}", + &Scope{}, + true, + }, + + { + "foo ${rand(42)} ", + &Scope{ + FuncMap: map[string]Function{ + "rand": Function{ + ReturnType: ast.TypeString, + Callback: func([]interface{}) (interface{}, error) { + return "42", nil + }, + }, + }, + }, + true, + }, + } + + for _, tc := range cases { + node, err := Parse(tc.Input) + if err != nil { + t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input) + } + + visitor := &IdentifierCheck{Scope: tc.Scope} + err = visitor.Visit(node) + if (err != nil) != tc.Error { + t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input) + } + } +} diff --git a/config/lang/engine.go b/config/lang/engine.go index 8b4b607939..96eb27c64f 100644 --- a/config/lang/engine.go +++ b/config/lang/engine.go @@ -27,10 +27,22 @@ type SemanticChecker func(ast.Node) error // Execute executes the given ast.Node and returns its final value, its // type, and an error if one exists. func (e *Engine) Execute(root ast.Node) (interface{}, ast.Type, error) { - // Run the type checker + // Build our own semantic checks that we always run tv := &TypeVisitor{Scope: e.GlobalScope} - if err := tv.Visit(root); err != nil { - return nil, ast.TypeInvalid, err + ic := &IdentifierCheck{Scope: e.GlobalScope} + + // Build up the semantic checks for execution + checks := make( + []SemanticChecker, len(e.SemanticChecks), len(e.SemanticChecks)+2) + copy(checks, e.SemanticChecks) + checks = append(checks, ic.Visit) + checks = append(checks, tv.Visit) + + // Run the semantic checks + for _, check := range checks { + if err := check(root); err != nil { + return nil, ast.TypeInvalid, err + } } // Execute