[ MO ] Reinfer shape sub-graphs once (#1908)
* [ MO ] Reinfer shape sub-graphs once * feedback * feedback
This commit is contained in:
committed by
GitHub
parent
719797326b
commit
e6c371ae2e
@@ -24,6 +24,7 @@ from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForC
|
||||
from extensions.middle.pass_separator import PostMiddleStart
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.port import Port
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.utils.error import Error
|
||||
|
||||
@@ -148,5 +149,12 @@ class ApplyPermutation(MiddleReplacementPattern):
|
||||
shape_ops = graph.get_op_nodes(op='ShapeOf')
|
||||
for shape in shape_ops:
|
||||
shape.infer(shape)
|
||||
|
||||
def reinfer_once(in_port: Port):
|
||||
node = in_port.node
|
||||
if not node.soft_get('reinferred', False):
|
||||
node.infer(node)
|
||||
node['reinferred'] = True
|
||||
|
||||
LayoutChangeForConstantShapePaths().find_shape_subgraph_endpoints(
|
||||
[shape.out_port(0) for shape in shape_ops], None, lambda in_port: in_port.node.infer(in_port.node))
|
||||
out_ports=[shape.out_port(0) for shape in shape_ops], action=reinfer_once)
|
||||
|
||||
Reference in New Issue
Block a user