[ MO ] Reinfer shape sub-graphs once (#1908)

* [ MO ] Reinfer shape sub-graphs once

* feedback

* feedback
This commit is contained in:
Evgenya Stepyreva
2020-08-24 14:30:41 +03:00
committed by GitHub
parent 719797326b
commit e6c371ae2e

View File

@@ -24,6 +24,7 @@ from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForC
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error
@@ -148,5 +149,12 @@ class ApplyPermutation(MiddleReplacementPattern):
shape_ops = graph.get_op_nodes(op='ShapeOf')
for shape in shape_ops:
shape.infer(shape)
def reinfer_once(in_port: Port):
node = in_port.node
if not node.soft_get('reinferred', False):
node.infer(node)
node['reinferred'] = True
LayoutChangeForConstantShapePaths().find_shape_subgraph_endpoints(
[shape.out_port(0) for shape in shape_ops], None, lambda in_port: in_port.node.infer(in_port.node))
out_ports=[shape.out_port(0) for shape in shape_ops], action=reinfer_once)