Add keep split output ports without consumers (#5136)

* Add keep split output ports without consumers

* Fix ir reader for split outputs

* Update unit tests

* Refactoring code according to review

* Fix unit test

* Fix
This commit is contained in:
iliya mironov 2021-04-12 17:49:53 +03:00 committed by GitHub
parent 6d2740a335
commit efeff1ee3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 101 additions and 16 deletions

View File

@ -106,6 +106,8 @@ class RemoveConstToResult(BackReplacementPattern):
Transformation looks for a constant sub-graph followed by Result operation. Transformation looks for a constant sub-graph followed by Result operation.
If sub-graph is Const->data->Result -- then all three nodes are removed. If sub-graph is Const->data->Result -- then all three nodes are removed.
If there is more complex constant sub-graph -- then only Result node is removed. If there is more complex constant sub-graph -- then only Result node is removed.
If Result node has keep_output_port attribute True the node will not to be removed from graph but
the Result node will not to be saved to IR. Only port will be kept in IR.
Currently IE is unable to handle such graph so this transformation is a work around for such case. Currently IE is unable to handle such graph so this transformation is a work around for such case.
For instance, this case appears for Wide and Deep model. For instance, this case appears for Wide and Deep model.
@ -123,7 +125,8 @@ class RemoveConstToResult(BackReplacementPattern):
return dict( return dict(
nodes=[ nodes=[
('const_data', {'kind': 'data', 'value': lambda value: value is not None}), ('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
('result_node', {'type': 'Result', 'kind': 'op'}), ('result_node', {'type': 'Result', 'kind': 'op',
'keep_output_port': lambda attr: not attr}),
], ],
edges=[ edges=[
('const_data', 'result_node') ('const_data', 'result_node')

View File

@ -107,7 +107,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
nodes = [ nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}), ('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data', 'value': np.array(5)}), ('const_data', {'kind': 'data', 'value': np.array(5)}),
('result_node', {'type': 'Result', 'kind': 'op'}), ('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': False}),
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}), ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_1_data', {'kind': 'data'}), ('placeholder_1_data', {'kind': 'data'}),
@ -150,6 +150,58 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
self.assertNotIn('const_data', graph.node) self.assertNotIn('const_data', graph.node)
self.assertNotIn('result_node', graph.node) self.assertNotIn('result_node', graph.node)
def test_only_consumer_keep_result(self):
"""Result node is only consumer of Const data node"""
nodes = [
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_1_data', {'kind': 'data'}),
('placeholder_2', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_2_data', {'kind': 'data'}),
('shape_of', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
('shape_of_data', {'kind': 'data'}),
('split', {'type': 'Split', 'kind': 'op', 'op': 'Split'}),
('split_data1', {'kind': 'data'}),
('split_data2', {'kind': 'data'}),
('result_node1', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul'}),
('mul_data', {'kind': 'data'}),
('result_node2', {'type': 'Result', 'kind': 'op'}),
]
edges = [
('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_1_data', 'shape_of'),
('shape_of', 'shape_of_data'),
('shape_of_data', 'split'),
('split', 'split_data1', {'in': 0}),
('split', 'split_data2', {'in': 1}),
('split_data1', 'result_node1'),
('split_data2', 'mul'),
('placeholder_2_data', 'mul'),
('mul', 'mul_data'),
('mul_data', 'result_node2'),
]
graph = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
tested_pattern = RemoveConstToResult()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mul_data')
self.assertTrue(flag, resp)
self.assertIn('split_data1', graph.node)
self.assertIn('split_data2', graph.node)
self.assertIn('result_node1', graph.node)
def test_two_consumers(self): def test_two_consumers(self):
"""Const data node has two consumers: Result and ReLu""" """Const data node has two consumers: Result and ReLu"""
nodes = [ nodes = [
@ -190,3 +242,34 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data') (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
self.assertNotIn('result_node', graph.node) self.assertNotIn('result_node', graph.node)
def test_two_consumers_keep_outputs(self):
"""Const data node has two consumers: Result and ReLu"""
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data', 'value': np.array(5)}),
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),
]
edges = [
('const_node', 'const_data'),
('const_data', 'result_node'),
('const_data', 'relu_1'),
('relu_1', 'relu_1_data')
]
graph = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
tested_pattern = RemoveConstToResult()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
self.assertTrue(flag, resp)
self.assertIn('result_node', graph.node)

View File

@ -52,7 +52,7 @@ class TopKNormalizer(BackReplacementPattern):
""" """
if node.out_port(0).disconnected(): if node.out_port(0).disconnected():
output = Result(node.graph, {'name': node.name + '/Result_port_0/', output = Result(node.graph, {'name': node.name + '/Result_port_0/',
'remove_from_xml': node.has_and_set('remove_values_output')}).create_node() 'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
node.out_port(0).get_connection().set_destination(output.in_port(0)) node.out_port(0).get_connection().set_destination(output.in_port(0))
if node.out_port(1).disconnected(): if node.out_port(1).disconnected():
output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node() output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()

View File

@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
from mo.graph.graph import Graph from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.result import Result from mo.ops.result import Result
@ -21,23 +21,19 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
def run_after(self): def run_after(self):
return [TensorIteratorMerge] return [TensorIteratorMerge]
@staticmethod def find_and_replace_pattern(self, graph: Graph):
def pattern(): for split_node in graph.get_op_nodes(op='Split'):
return dict( AddFakeOutputsToSplit.split_normalize_outputs(split_node)
nodes=[('op', dict(kind='op', op='Split'))],
edges=[],
)
@staticmethod @staticmethod
def replace_pattern(graph: Graph, match: dict): def split_normalize_outputs(node: Node):
node = match['op']
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count: if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
for p in range(node.out_ports_count): for p in range(node.out_ports_count):
if p not in node.out_ports(): if p not in node.out_ports():
node.add_output_port(p) node.add_output_port(p)
if node.out_port(p).disconnected(): if node.out_port(p).disconnected():
res_node = Result(graph, {'name': node.name + '/Fake_output_{}/'.format(p)}).create_node() res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
'keep_output_port': True}).create_node()
node.out_port(p).connect(res_node.in_port(0)) node.out_port(p).connect(res_node.in_port(0))
@ -76,4 +72,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
if not node.has_valid('out_ports_count'): if not node.has_valid('out_ports_count'):
node['out_ports_count'] = len(size_splits) node['out_ports_count'] = len(size_splits)
AddFakeOutputsToSplit().replace_pattern(graph, match) AddFakeOutputsToSplit().split_normalize_outputs(node)

View File

@ -247,7 +247,7 @@ def serialize_node_attributes(
unsupported): unsupported):
# the Result op may be marked so it should not appear in the IR. For example, refer to transformation # the Result op may be marked so it should not appear in the IR. For example, refer to transformation
# model-optimizer/extensions/back/TopKNormalizer.py # model-optimizer/extensions/back/TopKNormalizer.py
if isinstance(node, Node) and node.soft_get('result' == 'Result') and node.has_and_set('remove_from_xml'): if isinstance(node, Node) and node.soft_get('type') == 'Result' and node.has_and_set('keep_output_port'):
return return
try: try:
for s in schema: for s in schema:

View File

@ -7,6 +7,7 @@ import os
import numpy as np import numpy as np
from extensions.back.TopKNormalizer import TopKNormalizer from extensions.back.TopKNormalizer import TopKNormalizer
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
from extensions.ops.Cast import Cast from extensions.ops.Cast import Cast
from extensions.ops.ReduceOps import ReduceOp from extensions.ops.ReduceOps import ReduceOp
from extensions.ops.activation_ops import Activation from extensions.ops.activation_ops import Activation
@ -272,6 +273,8 @@ postprocessing_op_nodes = {
'Assign': assign_add_output_result, 'Assign': assign_add_output_result,
'TensorIterator': ti_add_edge_attrs, 'TensorIterator': ti_add_edge_attrs,
'TopK': TopKNormalizer.normalize_outputs, 'TopK': TopKNormalizer.normalize_outputs,
# Call normalize Split outputs for generated IR by ir-reader
'Split': AddFakeOutputsToSplit.split_normalize_outputs
} }