mirror of
https://github.com/opentofu/opentofu.git
synced 2025-02-25 18:45:20 -06:00
config/lang: implement identifier semantic check
This commit is contained in:
parent
5abbde3ac9
commit
8ce7ef6188
79
config/lang/check_identifier.go
Normal file
79
config/lang/check_identifier.go
Normal file
@ -0,0 +1,79 @@
|
||||
package lang
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/terraform/config/lang/ast"
|
||||
)
|
||||
|
||||
// IdentifierCheck is a SemanticCheck that checks that all identifiers
|
||||
// resolve properly and that the right number of arguments are passed
|
||||
// to functions.
|
||||
type IdentifierCheck struct {
|
||||
Scope *Scope
|
||||
|
||||
err error
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (c *IdentifierCheck) Visit(root ast.Node) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
defer c.reset()
|
||||
root.Accept(c.visit)
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *IdentifierCheck) visit(raw ast.Node) {
|
||||
if c.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch n := raw.(type) {
|
||||
case *ast.Call:
|
||||
c.visitCall(n)
|
||||
case *ast.VariableAccess:
|
||||
c.visitVariableAccess(n)
|
||||
case *ast.Concat:
|
||||
// Ignore
|
||||
case *ast.LiteralNode:
|
||||
// Ignore
|
||||
default:
|
||||
c.createErr(n, fmt.Sprintf("unknown node: %#v", raw))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IdentifierCheck) visitCall(n *ast.Call) {
|
||||
// Look up the function in the map
|
||||
function, ok := c.Scope.LookupFunc(n.Func)
|
||||
if !ok {
|
||||
c.createErr(n, fmt.Sprintf("unknown function called: %s", n.Func))
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the number of arguments
|
||||
if len(n.Args) != len(function.ArgTypes) {
|
||||
c.createErr(n, fmt.Sprintf(
|
||||
"%s: expected %d arguments, got %d",
|
||||
n.Func, len(function.ArgTypes), len(n.Args)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IdentifierCheck) visitVariableAccess(n *ast.VariableAccess) {
|
||||
// Look up the variable in the map
|
||||
if _, ok := c.Scope.LookupVar(n.Name); !ok {
|
||||
c.createErr(n, fmt.Sprintf(
|
||||
"unknown variable accessed: %s", n.Name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IdentifierCheck) createErr(n ast.Node, str string) {
|
||||
c.err = fmt.Errorf("%s: %s", n.Pos(), str)
|
||||
}
|
||||
|
||||
func (c *IdentifierCheck) reset() {
|
||||
c.err = nil
|
||||
}
|
89
config/lang/check_identifier_test.go
Normal file
89
config/lang/check_identifier_test.go
Normal file
@ -0,0 +1,89 @@
|
||||
package lang
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/terraform/config/lang/ast"
|
||||
)
|
||||
|
||||
func TestIdentifierCheck(t *testing.T) {
|
||||
cases := []struct {
|
||||
Input string
|
||||
Scope *Scope
|
||||
Error bool
|
||||
}{
|
||||
{
|
||||
"foo",
|
||||
&Scope{},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo ${bar} success",
|
||||
&Scope{
|
||||
VarMap: map[string]Variable{
|
||||
"bar": Variable{
|
||||
Value: "baz",
|
||||
Type: ast.TypeString,
|
||||
},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo ${bar}",
|
||||
&Scope{},
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"foo ${rand()} success",
|
||||
&Scope{
|
||||
FuncMap: map[string]Function{
|
||||
"rand": Function{
|
||||
ReturnType: ast.TypeString,
|
||||
Callback: func([]interface{}) (interface{}, error) {
|
||||
return "42", nil
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
||||
{
|
||||
"foo ${rand()}",
|
||||
&Scope{},
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"foo ${rand(42)} ",
|
||||
&Scope{
|
||||
FuncMap: map[string]Function{
|
||||
"rand": Function{
|
||||
ReturnType: ast.TypeString,
|
||||
Callback: func([]interface{}) (interface{}, error) {
|
||||
return "42", nil
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
node, err := Parse(tc.Input)
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input)
|
||||
}
|
||||
|
||||
visitor := &IdentifierCheck{Scope: tc.Scope}
|
||||
err = visitor.Visit(node)
|
||||
if (err != nil) != tc.Error {
|
||||
t.Fatalf("Error: %s\n\nInput: %s", err, tc.Input)
|
||||
}
|
||||
}
|
||||
}
|
@ -27,10 +27,22 @@ type SemanticChecker func(ast.Node) error
|
||||
// Execute executes the given ast.Node and returns its final value, its
|
||||
// type, and an error if one exists.
|
||||
func (e *Engine) Execute(root ast.Node) (interface{}, ast.Type, error) {
|
||||
// Run the type checker
|
||||
// Build our own semantic checks that we always run
|
||||
tv := &TypeVisitor{Scope: e.GlobalScope}
|
||||
if err := tv.Visit(root); err != nil {
|
||||
return nil, ast.TypeInvalid, err
|
||||
ic := &IdentifierCheck{Scope: e.GlobalScope}
|
||||
|
||||
// Build up the semantic checks for execution
|
||||
checks := make(
|
||||
[]SemanticChecker, len(e.SemanticChecks), len(e.SemanticChecks)+2)
|
||||
copy(checks, e.SemanticChecks)
|
||||
checks = append(checks, ic.Visit)
|
||||
checks = append(checks, tv.Visit)
|
||||
|
||||
// Run the semantic checks
|
||||
for _, check := range checks {
|
||||
if err := check(root); err != nil {
|
||||
return nil, ast.TypeInvalid, err
|
||||
}
|
||||
}
|
||||
|
||||
// Execute
|
||||
|
Loading…
Reference in New Issue
Block a user