config/lang: implement identifier semantic check

This commit is contained in:
Mitchell Hashimoto 2015-01-13 11:24:42 -08:00
parent 5abbde3ac9
commit 8ce7ef6188
3 changed files with 183 additions and 3 deletions

View File

@ -0,0 +1,79 @@
package lang
import (
"fmt"
"sync"
"github.com/hashicorp/terraform/config/lang/ast"
)
// IdentifierCheck is a SemanticCheck that checks that all identifiers
// resolve properly and that the right number of arguments are passed
// to functions.
type IdentifierCheck struct {
Scope *Scope
err error
lock sync.Mutex
}
func (c *IdentifierCheck) Visit(root ast.Node) error {
c.lock.Lock()
defer c.lock.Unlock()
defer c.reset()
root.Accept(c.visit)
return c.err
}
func (c *IdentifierCheck) visit(raw ast.Node) {
if c.err != nil {
return
}
switch n := raw.(type) {
case *ast.Call:
c.visitCall(n)
case *ast.VariableAccess:
c.visitVariableAccess(n)
case *ast.Concat:
// Ignore
case *ast.LiteralNode:
// Ignore
default:
c.createErr(n, fmt.Sprintf("unknown node: %#v", raw))
}
}
func (c *IdentifierCheck) visitCall(n *ast.Call) {
// Look up the function in the map
function, ok := c.Scope.LookupFunc(n.Func)
if !ok {
c.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func))
return
}
// Verify the number of arguments
if len(n.Args) != len(function.ArgTypes) {
c.createErr(n, fmt.Sprintf(
"%s: expected %d arguments, got %d",
n.Func, len(function.ArgTypes), len(n.Args)))
return
}
}
func (c *IdentifierCheck) visitVariableAccess(n *ast.VariableAccess) {
// Look up the variable in the map
if _, ok := c.Scope.LookupVar(n.Name); !ok {
c.createErr(n, fmt.Sprintf(
"unknown variable accessed: %s", n.Name))
return
}
}
func (c *IdentifierCheck) createErr(n ast.Node, str string) {
c.err = fmt.Errorf("%s: %s", n.Pos(), str)
}
func (c *IdentifierCheck) reset() {
c.err = nil
}

View File

@ -0,0 +1,89 @@
package lang
import (
"testing"
"github.com/hashicorp/terraform/config/lang/ast"
)
func TestIdentifierCheck(t *testing.T) {
cases := []struct {
Input string
Scope *Scope
Error bool
}{
{
"foo",
&Scope{},
false,
},
{
"foo ${bar} success",
&Scope{
VarMap: map[string]Variable{
"bar": Variable{
Value: "baz",
Type: ast.TypeString,
},
},
},
false,
},
{
"foo ${bar}",
&Scope{},
true,
},
{
"foo ${rand()} success",
&Scope{
FuncMap: map[string]Function{
"rand": Function{
ReturnType: ast.TypeString,
Callback: func([]interface{}) (interface{}, error) {
return "42", nil
},
},
},
},
false,
},
{
"foo ${rand()}",
&Scope{},
true,
},
{
"foo ${rand(42)} ",
&Scope{
FuncMap: map[string]Function{
"rand": Function{
ReturnType: ast.TypeString,
Callback: func([]interface{}) (interface{}, error) {
return "42", nil
},
},
},
},
true,
},
}
for _, tc := range cases {
node, err := Parse(tc.Input)
if err != nil {
t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input)
}
visitor := &IdentifierCheck{Scope: tc.Scope}
err = visitor.Visit(node)
if (err != nil) != tc.Error {
t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input)
}
}
}

View File

@ -27,10 +27,22 @@ type SemanticChecker func(ast.Node) error
// Execute executes the given ast.Node and returns its final value, its
// type, and an error if one exists.
func (e *Engine) Execute(root ast.Node) (interface{}, ast.Type, error) {
// Run the type checker
// Build our own semantic checks that we always run
tv := &TypeVisitor{Scope: e.GlobalScope}
if err := tv.Visit(root); err != nil {
return nil, ast.TypeInvalid, err
ic := &IdentifierCheck{Scope: e.GlobalScope}
// Build up the semantic checks for execution
checks := make(
[]SemanticChecker, len(e.SemanticChecks), len(e.SemanticChecks)+2)
copy(checks, e.SemanticChecks)
checks = append(checks, ic.Visit)
checks = append(checks, tv.Visit)
// Run the semantic checks
for _, check := range checks {
if err := check(root); err != nil {
return nil, ast.TypeInvalid, err
}
}
// Execute