From efeff1ee3e48fd6874ff0ba5477215aaaec3abec Mon Sep 17 00:00:00 2001
From: iliya mironov
Date: Mon, 12 Apr 2021 17:49:53 +0300
Subject: [PATCH] Add keep split output ports without consumers (#5136)

* Add keep split output ports without consumers

* Fix ir reader for split outputs

* Update unit tests

* Refactoring code according to review

* Fix unit test

* Fix
---
 .../back/SpecialNodesFinalization.py          |  5 +-
 .../back/SpecialNodesFinalization_test.py     | 85 ++++++++++++++++++-
 .../extensions/back/TopKNormalizer.py         |  2 +-
 .../extensions/middle/FakeSplitOutputs.py     | 20 ++---
 .../mo/back/ie_ir_ver_2/emitter.py            |  2 +-
 .../mo/utils/ir_reader/layer_to_class.py      |  3 +
 6 files changed, 101 insertions(+), 16 deletions(-)

diff --git a/model-optimizer/extensions/back/SpecialNodesFinalization.py b/model-optimizer/extensions/back/SpecialNodesFinalization.py
index c177d7873cf..915c5670d74 100644
--- a/model-optimizer/extensions/back/SpecialNodesFinalization.py
+++ b/model-optimizer/extensions/back/SpecialNodesFinalization.py
@@ -106,6 +106,8 @@ class RemoveConstToResult(BackReplacementPattern):
     Transformation looks for a constant sub-graph followed by Result operation.
     If sub-graph is Const->data->Result -- then all three nodes are removed.
     If there is more complex constant sub-graph -- then only Result node is removed.
+    If the Result node has the keep_output_port attribute set to True, the node is not removed from the graph,
+    but it is not written to the IR either. Only the output port is kept in the IR.
     Currently IE is unable to handle such graph so this transformation is a work around for such case.
     For instance, this case appears for Wide and Deep model.
     """
@@ -123,7 +125,8 @@ class RemoveConstToResult(BackReplacementPattern):
         return dict(
             nodes=[
                 ('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
-                ('result_node', {'type': 'Result', 'kind': 'op'}),
+                ('result_node', {'type': 'Result', 'kind': 'op',
+                                 'keep_output_port': lambda attr: not attr}),
             ],
             edges=[
                 ('const_data', 'result_node')
diff --git a/model-optimizer/extensions/back/SpecialNodesFinalization_test.py b/model-optimizer/extensions/back/SpecialNodesFinalization_test.py
index 46523aac61f..ea6516adbc3 100644
--- a/model-optimizer/extensions/back/SpecialNodesFinalization_test.py
+++ b/model-optimizer/extensions/back/SpecialNodesFinalization_test.py
@@ -107,7 +107,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
         nodes = [
             ('const_node', {'type': 'Const', 'kind': 'op'}),
             ('const_data', {'kind': 'data', 'value': np.array(5)}),
-            ('result_node', {'type': 'Result', 'kind': 'op'}),
+            ('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': False}),
 
             ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
             ('placeholder_1_data', {'kind': 'data'}),
@@ -150,6 +150,58 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
         self.assertNotIn('const_data', graph.node)
         self.assertNotIn('result_node', graph.node)
 
+
+    def test_only_consumer_keep_result(self):
+        """Result node with keep_output_port is the only consumer of a Split output"""
+        nodes = [
+            ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
+            ('placeholder_1_data', {'kind': 'data'}),
+            ('placeholder_2', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
+            ('placeholder_2_data', {'kind': 'data'}),
+            ('shape_of', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
+            ('shape_of_data', {'kind': 'data'}),
+            ('split', {'type': 'Split', 'kind': 'op', 'op': 'Split'}),
+            ('split_data1', {'kind': 'data'}),
+            ('split_data2', {'kind': 'data'}),
+            ('result_node1', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
+
+            ('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul'}),
+            ('mul_data', {'kind': 'data'}),
+            ('result_node2', {'type': 'Result', 'kind': 'op'}),
+        ]
+        edges = [
+            ('placeholder_1', 'placeholder_1_data'),
+            ('placeholder_2', 'placeholder_2_data'),
+            ('placeholder_1_data', 'shape_of'),
+            ('shape_of', 'shape_of_data'),
+            ('shape_of_data', 'split'),
+            ('split', 'split_data1', {'in': 0}),
+            ('split', 'split_data2', {'in': 1}),
+
+            ('split_data1', 'result_node1'),
+            ('split_data2', 'mul'),
+            ('placeholder_2_data', 'mul'),
+            ('mul', 'mul_data'),
+            ('mul_data', 'result_node2'),
+        ]
+
+        graph = build_graph_with_attrs(
+            nodes_with_attrs=nodes,
+            edges_with_attrs=edges,
+        )
+        graph_ref = build_graph_with_attrs(
+            nodes_with_attrs=nodes,
+            edges_with_attrs=edges,
+        )
+        tested_pattern = RemoveConstToResult()
+        tested_pattern.find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, last_node='mul_data')
+        self.assertTrue(flag, resp)
+        self.assertIn('split_data1', graph.node)
+        self.assertIn('split_data2', graph.node)
+        self.assertIn('result_node1', graph.node)
+
+
     def test_two_consumers(self):
         """Const data node has two consumers: Result and ReLu"""
         nodes = [
@@ -190,3 +242,34 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
         (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
         self.assertTrue(flag, resp)
         self.assertNotIn('result_node', graph.node)
+
+
+    def test_two_consumers_keep_outputs(self):
+        """Const data node has two consumers: Result with keep_output_port=True and ReLU"""
+        nodes = [
+            ('const_node', {'type': 'Const', 'kind': 'op'}),
+            ('const_data', {'kind': 'data', 'value': np.array(5)}),
+            ('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
+            ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
+            ('relu_1_data', {'kind': 'data'}),
+        ]
+        edges = [
+            ('const_node', 'const_data'),
+            ('const_data', 'result_node'),
+            ('const_data', 'relu_1'),
+            ('relu_1', 'relu_1_data')
+        ]
+
+        graph = build_graph_with_attrs(
+            nodes_with_attrs=nodes,
+            edges_with_attrs=edges,
+        )
+        graph_ref = build_graph_with_attrs(
+            nodes_with_attrs=nodes,
+            edges_with_attrs=edges,
+        )
+        tested_pattern = RemoveConstToResult()
+        tested_pattern.find_and_replace_pattern(graph)
+        (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
+        self.assertTrue(flag, resp)
+        self.assertIn('result_node', graph.node)
diff --git a/model-optimizer/extensions/back/TopKNormalizer.py b/model-optimizer/extensions/back/TopKNormalizer.py
index af756a97bfc..6fc3f33e3a9 100644
--- a/model-optimizer/extensions/back/TopKNormalizer.py
+++ b/model-optimizer/extensions/back/TopKNormalizer.py
@@ -52,7 +52,7 @@ class TopKNormalizer(BackReplacementPattern):
         """
         if node.out_port(0).disconnected():
             output = Result(node.graph, {'name': node.name + '/Result_port_0/',
-                                         'remove_from_xml': node.has_and_set('remove_values_output')}).create_node()
+                                         'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
             node.out_port(0).get_connection().set_destination(output.in_port(0))
         if node.out_port(1).disconnected():
             output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()
diff --git a/model-optimizer/extensions/middle/FakeSplitOutputs.py b/model-optimizer/extensions/middle/FakeSplitOutputs.py
index da943974f0b..b5ed4b31c94 100644
--- a/model-optimizer/extensions/middle/FakeSplitOutputs.py
+++ b/model-optimizer/extensions/middle/FakeSplitOutputs.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
-from mo.graph.graph import Graph
+from mo.graph.graph import Graph, Node
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.result import Result
 
@@ -21,23 +21,19 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
     def run_after(self):
         return [TensorIteratorMerge]
 
-    @staticmethod
-    def pattern():
-        return dict(
-            nodes=[('op', dict(kind='op', op='Split'))],
-            edges=[],
-        )
+    def find_and_replace_pattern(self, graph: Graph):
+        for split_node in graph.get_op_nodes(op='Split'):
+            AddFakeOutputsToSplit.split_normalize_outputs(split_node)
 
     @staticmethod
-    def replace_pattern(graph: Graph, match: dict):
-        node = match['op']
-
+    def split_normalize_outputs(node: Node):
         if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
             for p in range(node.out_ports_count):
                 if p not in node.out_ports():
                     node.add_output_port(p)
                 if node.out_port(p).disconnected():
-                    res_node = Result(graph, {'name': node.name + '/Fake_output_{}/'.format(p)}).create_node()
+                    res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
+                                                   'keep_output_port': True}).create_node()
                     node.out_port(p).connect(res_node.in_port(0))
 
 
@@ -76,4 +72,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
         if not node.has_valid('out_ports_count'):
             node['out_ports_count'] = len(size_splits)
 
-        AddFakeOutputsToSplit().replace_pattern(graph, match)
+        AddFakeOutputsToSplit().split_normalize_outputs(node)
diff --git a/model-optimizer/mo/back/ie_ir_ver_2/emitter.py b/model-optimizer/mo/back/ie_ir_ver_2/emitter.py
index 0a80a1f102d..85dce75201b 100644
--- a/model-optimizer/mo/back/ie_ir_ver_2/emitter.py
+++ b/model-optimizer/mo/back/ie_ir_ver_2/emitter.py
@@ -247,7 +247,7 @@ def serialize_node_attributes(
         unsupported):
     # the Result op may be marked so it should not appear in the IR. For example, refer to transformation
     # model-optimizer/extensions/back/TopKNormalizer.py
-    if isinstance(node, Node) and node.soft_get('result' == 'Result') and node.has_and_set('remove_from_xml'):
+    if isinstance(node, Node) and node.soft_get('type') == 'Result' and node.has_and_set('keep_output_port'):
         return
     try:
         for s in schema:
diff --git a/model-optimizer/mo/utils/ir_reader/layer_to_class.py b/model-optimizer/mo/utils/ir_reader/layer_to_class.py
index 36bae0358e7..67bc917f747 100644
--- a/model-optimizer/mo/utils/ir_reader/layer_to_class.py
+++ b/model-optimizer/mo/utils/ir_reader/layer_to_class.py
@@ -7,6 +7,7 @@ import os
 import numpy as np
 
 from extensions.back.TopKNormalizer import TopKNormalizer
+from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
 from extensions.ops.Cast import Cast
 from extensions.ops.ReduceOps import ReduceOp
 from extensions.ops.activation_ops import Activation
@@ -272,6 +273,8 @@ postprocessing_op_nodes = {
     'Assign': assign_add_output_result,
     'TensorIterator': ti_add_edge_attrs,
     'TopK': TopKNormalizer.normalize_outputs,
+    # Normalize Split output ports for an IR restored by the IR reader
+    'Split': AddFakeOutputsToSplit.split_normalize_outputs
 }
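
For context, the end-to-end behaviour introduced by this patch is: AddFakeOutputsToSplit attaches a fake Result
(flagged keep_output_port=True) to every declared Split output port that has no consumer, and the IR emitter then
skips such Result nodes so that only the port itself appears in the IR. The standalone Python sketch below
illustrates that flow; FakeNode and the two helper functions are simplified stand-ins for illustration only, not
Model Optimizer code.

    # Standalone sketch of the keep_output_port behaviour; illustrative stand-ins, not the MO node/graph API.
    class FakeNode:
        def __init__(self, type_, name, **attrs):
            self.type = type_
            self.name = name
            self.attrs = attrs

    def split_normalize_outputs(consumers, out_ports_count, split_name):
        # For every declared Split output port without a consumer, attach a fake
        # Result marked keep_output_port=True (mirrors AddFakeOutputsToSplit).
        for p in range(out_ports_count):
            if p not in consumers:
                consumers[p] = FakeNode('Result', '{}/Fake_output_{}/'.format(split_name, p),
                                        keep_output_port=True)
        return consumers

    def is_serialized(node):
        # Mirrors the emitter check: a Result flagged keep_output_port is skipped,
        # so only the output port remains visible in the IR.
        return not (node.type == 'Result' and node.attrs.get('keep_output_port', False))

    consumers = {1: FakeNode('Mul', 'mul')}      # only port 1 of a 2-output Split is consumed
    consumers = split_normalize_outputs(consumers, 2, 'split')
    print(sorted(n.name for n in consumers.values() if is_serialized(n)))   # prints ['mul']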