Compare ModuleInstance to Module without allocating, and similar (#2261)

Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
This commit is contained in:
Martin Atkins 2024-12-12 09:47:57 -08:00 committed by GitHub
parent 73e4a657ae
commit 27ab52fd03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 201 additions and 21 deletions

View File

@ -9,6 +9,8 @@ import (
"strings" "strings"
"github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/opentofu/opentofu/internal/tfdiags" "github.com/opentofu/opentofu/internal/tfdiags"
) )
@ -196,6 +198,22 @@ func ParseModule(traversal hcl.Traversal) (Module, tfdiags.Diagnostics) {
return mod, diags return mod, diags
} }
// ParseModuleStr is a helper wrapper around [ParseModule] that first tries
// to parse the given string as HCL traversal syntax.
func ParseModuleStr(str string) (Module, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
traversal, parseDiags := hclsyntax.ParseTraversalAbs([]byte(str), "", hcl.Pos{Line: 1, Column: 1})
diags = diags.Append(parseDiags)
if parseDiags.HasErrors() {
return nil, diags
}
addr, addrDiags := ParseModule(traversal)
diags = diags.Append(addrDiags)
return addr, diags
}
// parseModulePrefix parses a module address from the given traversal, // parseModulePrefix parses a module address from the given traversal,
// returning the module address and the remaining traversal. // returning the module address and the remaining traversal.
// For example, if the input traversal is ["module","a","module","b", // For example, if the input traversal is ["module","a","module","b",

View File

@ -471,6 +471,42 @@ func (m ModuleInstance) Module() Module {
return ret return ret
} }
// HasSameModule returns true if calling [ModuleInstance.Module] on both
// the receiver and the other given address would produce equal [Module]
// addresses.
//
// This is here only as an optimization to avoid the overhead of constructing
// two [Module] values just to compare them and then throw them away.
func (m ModuleInstance) HasSameModule(other ModuleInstance) bool {
if len(m) != len(other) {
return false
}
for i := range m {
if m[i].Name != other[i].Name {
return false
}
}
return true
}
// HasSameModule returns true if calling [ModuleInstance.Module] on the
// receiver would return a [Module] address equal to the one given as
// an argument.
//
// This is here only as an optimization to avoid the overhead of constructing
// a [Module] value from the reciever just to compare it and then throw it away.
func (m ModuleInstance) IsForModule(module Module) bool {
if len(m) != len(module) {
return false
}
for i := range m {
if m[i].Name != module[i] {
return false
}
}
return true
}
func (m ModuleInstance) AddrType() TargetableAddrType { func (m ModuleInstance) AddrType() TargetableAddrType {
return ModuleInstanceAddrType return ModuleInstanceAddrType
} }

View File

@ -83,6 +83,126 @@ func TestModuleInstanceEqual_false(t *testing.T) {
} }
} }
func TestHasSameModule(t *testing.T) {
tests := []struct {
a string
b string
wantSame bool
}{
{
"module.foo",
"module.bar",
false,
},
{
"module.foo",
"module.foo.module.bar",
false,
},
{
"module.foo[1]",
"module.bar[1]",
false,
},
{
`module.foo[1]`,
`module.foo["1"]`,
true,
},
{
"module.foo.module.bar",
"module.foo[1].module.bar",
true,
},
{
`module.foo.module.bar`,
`module.foo["a"].module.bar`,
true,
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%#v.HasSameModule(%#v)", test.a, test.b), func(t *testing.T) {
a, diags := ParseModuleInstanceStr(test.a)
if len(diags) > 0 {
t.Fatalf("invalid module instance address %s: %s", test.a, diags.Err())
}
b, diags := ParseModuleInstanceStr(test.b)
if len(diags) > 0 {
t.Fatalf("invalid module instance address %s: %s", test.b, diags.Err())
}
// "HasSameModule" is commutative, so we'll test it both ways at once
gotAB := a.HasSameModule(b)
gotBA := b.HasSameModule(a)
if gotAB != test.wantSame {
t.Errorf("wrong result\n1st: %s\n2nd: %s\ngot: %t\nwant: %t", a, b, gotAB, test.wantSame)
}
if gotBA != test.wantSame {
t.Errorf("wrong result\n1st: %s\n2nd: %s\ngot: %t\nwant: %t", b, a, gotBA, test.wantSame)
}
})
}
}
func TestIsForModule(t *testing.T) {
tests := []struct {
inst string
mod string
want bool
}{
{
"module.foo",
"module.bar",
false,
},
{
"module.foo",
"module.foo.module.bar",
false,
},
{
"module.foo[1]",
"module.bar",
false,
},
{
`module.foo[1]`,
`module.foo`,
true,
},
{
"module.foo[1].module.bar",
"module.foo.module.bar",
true,
},
{
`module.foo["a"].module.bar`,
`module.foo.module.bar`,
true,
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%#v.IsForModule(%#v)", test.inst, test.mod), func(t *testing.T) {
inst, diags := ParseModuleInstanceStr(test.inst)
if len(diags) > 0 {
t.Fatalf("invalid module instance address %s: %s", test.inst, diags.Err())
}
mod, diags := ParseModuleStr(test.mod)
if len(diags) > 0 {
t.Fatalf("invalid module address %s: %s", test.mod, diags.Err())
}
got := inst.IsForModule(mod)
if got != test.want {
t.Errorf("wrong result\ninstance: %s\nmodule: %s\ngot: %t\nwant: %t", inst, mod, got, test.want)
}
})
}
}
func BenchmarkStringShort(b *testing.B) { func BenchmarkStringShort(b *testing.B) {
addr, _ := ParseModuleInstanceStr(`module.foo`) addr, _ := ParseModuleInstanceStr(`module.foo`)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {

View File

@ -176,3 +176,11 @@ func TestParseModule(t *testing.T) {
}) })
} }
} }
func mustParseModuleStr(str string) Module {
m, diags := ParseModuleStr(str)
if diags.HasErrors() {
panic(diags.ErrWithWarnings())
}
return m
}

View File

@ -727,17 +727,17 @@ func (from *MoveEndpointInModule) IsModuleReIndex(to *MoveEndpointInModule) bool
// the module call. We're not actually comparing indexes, so the // the module call. We're not actually comparing indexes, so the
// instance doesn't matter. // instance doesn't matter.
callAddr := f.Instance(NoKey).Module() callAddr := f.Instance(NoKey).Module()
return callAddr.Equal(t.Module()) return t.IsForModule(callAddr)
} }
case ModuleInstance: case ModuleInstance:
switch t := to.relSubject.(type) { switch t := to.relSubject.(type) {
case AbsModuleCall: case AbsModuleCall:
callAddr := t.Instance(NoKey).Module() callAddr := t.Instance(NoKey).Module()
return callAddr.Equal(f.Module()) return f.IsForModule(callAddr)
case ModuleInstance: case ModuleInstance:
return t.Module().Equal(f.Module()) return t.HasSameModule(f)
} }
} }

View File

@ -1375,7 +1375,7 @@ func TestSelectsModule(t *testing.T) {
}, },
{ {
Endpoint: &MoveEndpointInModule{ Endpoint: &MoveEndpointInModule{
module: mustParseModuleInstanceStr("module.foo").Module(), module: mustParseModuleStr("module.foo"),
relSubject: AbsModuleCall{ relSubject: AbsModuleCall{
Module: mustParseModuleInstanceStr("module.bar[2]"), Module: mustParseModuleInstanceStr("module.bar[2]"),
Call: ModuleCall{Name: "baz"}, Call: ModuleCall{Name: "baz"},
@ -1386,7 +1386,7 @@ func TestSelectsModule(t *testing.T) {
}, },
{ {
Endpoint: &MoveEndpointInModule{ Endpoint: &MoveEndpointInModule{
module: mustParseModuleInstanceStr("module.foo").Module(), module: mustParseModuleStr("module.foo"),
relSubject: AbsModuleCall{ relSubject: AbsModuleCall{
Module: mustParseModuleInstanceStr("module.bar[2]"), Module: mustParseModuleInstanceStr("module.bar[2]"),
Call: ModuleCall{Name: "baz"}, Call: ModuleCall{Name: "baz"},
@ -1407,7 +1407,7 @@ func TestSelectsModule(t *testing.T) {
}, },
{ {
Endpoint: &MoveEndpointInModule{ Endpoint: &MoveEndpointInModule{
module: mustParseModuleInstanceStr("module.foo").Module(), module: mustParseModuleStr("module.foo"),
relSubject: mustParseAbsResourceInstanceStr(`module.bar.resource.name["key"]`), relSubject: mustParseAbsResourceInstanceStr(`module.bar.resource.name["key"]`),
}, },
Addr: mustParseModuleInstanceStr(`module.foo[1].module.bar`), Addr: mustParseModuleInstanceStr(`module.foo[1].module.bar`),
@ -1429,7 +1429,7 @@ func TestSelectsModule(t *testing.T) {
}, },
{ {
Endpoint: &MoveEndpointInModule{ Endpoint: &MoveEndpointInModule{
module: mustParseModuleInstanceStr("module.nope").Module(), module: mustParseModuleStr("module.nope"),
relSubject: mustParseAbsResourceInstanceStr(`module.bar.resource.name["key"]`), relSubject: mustParseAbsResourceInstanceStr(`module.bar.resource.name["key"]`),
}, },
Addr: mustParseModuleInstanceStr(`module.foo[1].module.bar`), Addr: mustParseModuleInstanceStr(`module.foo[1].module.bar`),

View File

@ -94,7 +94,7 @@ func (s *State) Module(addr addrs.ModuleInstance) *Module {
func (s *State) ModuleInstances(addr addrs.Module) []*Module { func (s *State) ModuleInstances(addr addrs.Module) []*Module {
var ms []*Module var ms []*Module
for _, m := range s.Modules { for _, m := range s.Modules {
if m.Addr.Module().Equal(addr) { if m.Addr.IsForModule(addr) {
ms = append(ms, m) ms = append(ms, m)
} }
} }

View File

@ -212,7 +212,7 @@ func (ctx *BuiltinEvalContext) CloseProvider(addr addrs.AbsProviderConfig) error
func (ctx *BuiltinEvalContext) ConfigureProvider(addr addrs.AbsProviderConfig, providerKey addrs.InstanceKey, cfg cty.Value) tfdiags.Diagnostics { func (ctx *BuiltinEvalContext) ConfigureProvider(addr addrs.AbsProviderConfig, providerKey addrs.InstanceKey, cfg cty.Value) tfdiags.Diagnostics {
var diags tfdiags.Diagnostics var diags tfdiags.Diagnostics
if !addr.Module.Equal(ctx.Path().Module()) { if !ctx.Path().IsForModule(addr.Module) {
// This indicates incorrect use of ConfigureProvider: it should be used // This indicates incorrect use of ConfigureProvider: it should be used
// only from the module that the provider configuration belongs to. // only from the module that the provider configuration belongs to.
panic(fmt.Sprintf("%s configured by wrong module %s", addr, ctx.Path())) panic(fmt.Sprintf("%s configured by wrong module %s", addr, ctx.Path()))
@ -237,7 +237,7 @@ func (ctx *BuiltinEvalContext) ProviderInput(pc addrs.AbsProviderConfig) map[str
ctx.ProviderLock.Lock() ctx.ProviderLock.Lock()
defer ctx.ProviderLock.Unlock() defer ctx.ProviderLock.Unlock()
if !pc.Module.Equal(ctx.Path().Module()) { if !ctx.Path().IsForModule(pc.Module) {
// This indicates incorrect use of InitProvider: it should be used // This indicates incorrect use of InitProvider: it should be used
// only from the module that the provider configuration belongs to. // only from the module that the provider configuration belongs to.
panic(fmt.Sprintf("%s initialized by wrong module %s", pc, ctx.Path())) panic(fmt.Sprintf("%s initialized by wrong module %s", pc, ctx.Path()))
@ -346,7 +346,7 @@ func (ctx *BuiltinEvalContext) EvaluateReplaceTriggeredBy(expr hcl.Expression, r
} }
// Do some validation to make sure we are expecting a change at all // Do some validation to make sure we are expecting a change at all
cfg := ctx.Evaluator.Config.Descendent(ctx.Path().Module()) cfg := ctx.Evaluator.Config.DescendentForInstance(ctx.Path())
resCfg := cfg.Module.ResourceByAddr(resourceAddr) resCfg := cfg.Module.ResourceByAddr(resourceAddr)
if resCfg == nil { if resCfg == nil {
diags = diags.Append(&hcl.Diagnostic{ diags = diags.Append(&hcl.Diagnostic{
@ -470,7 +470,7 @@ func (ctx *BuiltinEvalContext) EvaluationScope(self addrs.Referenceable, source
var providerKey addrs.InstanceKey var providerKey addrs.InstanceKey
if providedBy.KeyExpression != nil && ctx.Evaluator.Operation != walkValidate { if providedBy.KeyExpression != nil && ctx.Evaluator.Operation != walkValidate {
moduleInstanceForKey := ctx.PathValue[:len(providedBy.KeyModule)] moduleInstanceForKey := ctx.PathValue[:len(providedBy.KeyModule)]
if !moduleInstanceForKey.Module().Equal(providedBy.KeyModule) { if !moduleInstanceForKey.IsForModule(providedBy.KeyModule) {
panic(fmt.Sprintf("Invalid module key expression location %s in function %s", providedBy.KeyModule, pf.String())) panic(fmt.Sprintf("Invalid module key expression location %s in function %s", providedBy.KeyModule, pf.String()))
} }

View File

@ -632,9 +632,7 @@ func graphNodesAreResourceInstancesInDifferentInstancesOfSameModule(a, b dag.Ver
} }
aModInst := aRI.ResourceInstanceAddr().Module aModInst := aRI.ResourceInstanceAddr().Module
bModInst := bRI.ResourceInstanceAddr().Module bModInst := bRI.ResourceInstanceAddr().Module
aMod := aModInst.Module() if !aModInst.HasSameModule(bModInst) {
bMod := bModInst.Module()
if !aMod.Equal(bMod) {
return false return false
} }
return !aModInst.Equal(bModInst) return !aModInst.Equal(bModInst)

View File

@ -146,7 +146,7 @@ func (n *NodeAbstractResourceInstance) resolveProvider(ctx EvalContext, hasExpan
} else { } else {
// Resolved from module instance // Resolved from module instance
moduleInstanceForKey := n.Addr.Module[:len(n.ResolvedProvider.KeyModule)] moduleInstanceForKey := n.Addr.Module[:len(n.ResolvedProvider.KeyModule)]
if !moduleInstanceForKey.Module().Equal(n.ResolvedProvider.KeyModule) { if !moduleInstanceForKey.IsForModule(n.ResolvedProvider.KeyModule) {
panic(fmt.Sprintf("Invalid module key expression location %s in resource %s", n.ResolvedProvider.KeyModule, n.Addr)) panic(fmt.Sprintf("Invalid module key expression location %s in resource %s", n.ResolvedProvider.KeyModule, n.Addr))
} }
@ -1974,9 +1974,8 @@ func (n *NodeAbstractResourceInstance) dependenciesHavePendingChanges(ctx EvalCo
for _, change := range changes.GetChangesForConfigResource(d) { for _, change := range changes.GetChangesForConfigResource(d) {
changeModInst := change.Addr.Module changeModInst := change.Addr.Module
changeMod := changeModInst.Module()
if changeMod.Equal(nMod) && !changeModInst.Equal(nModInst) { if changeModInst.IsForModule(nMod) && !changeModInst.Equal(nModInst) {
// Dependencies are tracked by configuration address, which // Dependencies are tracked by configuration address, which
// means we may have changes from other instances of parent // means we may have changes from other instances of parent
// modules. The actual reference can only take effect within // modules. The actual reference can only take effect within

View File

@ -249,7 +249,7 @@ func initMockEvalContext(resourceAddrs string, deposedKey states.DeposedKey) (*M
state := states.NewState() state := states.NewState()
absResource := mustResourceInstanceAddr(resourceAddrs) absResource := mustResourceInstanceAddr(resourceAddrs)
if !absResource.Module.Module().Equal(addrs.RootModule) { if !absResource.Module.IsRoot() {
state.EnsureModule(addrs.RootModuleInstance.Child(absResource.Module[0].Name, absResource.Module[0].InstanceKey)) state.EnsureModule(addrs.RootModuleInstance.Child(absResource.Module[0].Name, absResource.Module[0].InstanceKey))
} }

View File

@ -92,7 +92,7 @@ func TestNodeResourcePlanOrphan_Execute(t *testing.T) {
state := states.NewState() state := states.NewState()
absResource := mustResourceInstanceAddr(test.nodeAddress) absResource := mustResourceInstanceAddr(test.nodeAddress)
if !absResource.Module.Module().Equal(addrs.RootModule) { if !absResource.Module.IsRoot() {
state.EnsureModule(addrs.RootModuleInstance.Child(absResource.Module[0].Name, absResource.Module[0].InstanceKey)) state.EnsureModule(addrs.RootModuleInstance.Child(absResource.Module[0].Name, absResource.Module[0].InstanceKey))
} }

View File

@ -33,7 +33,8 @@ func (t *RemovedModuleTransformer) Transform(g *Graph) error {
if cc != nil { if cc != nil {
continue continue
} }
removed[m.Addr.Module().String()] = m.Addr.Module() mod := m.Addr.Module()
removed[mod.String()] = mod
log.Printf("[DEBUG] %s is no longer in configuration\n", m.Addr) log.Printf("[DEBUG] %s is no longer in configuration\n", m.Addr)
} }