Add keep split output ports without consumers (#5136)

* Add keep split output ports without consumers

* Fix ir reader for split outputs

* Update unit tests

* Refactoring code according to review

* Fix unit test

* Fix
This commit is contained in:
iliya mironov 2021-04-12 17:49:53 +03:00 committed by GitHub
parent 6d2740a335
commit efeff1ee3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 101 additions and 16 deletions

View File

@ -106,6 +106,8 @@ class RemoveConstToResult(BackReplacementPattern):
Transformation looks for a constant sub-graph followed by Result operation.
If sub-graph is Const->data->Result -- then all three nodes are removed.
If there is more complex constant sub-graph -- then only Result node is removed.
If the Result node has the keep_output_port attribute set to True, the node will not be removed from the graph,
but the Result node will not be saved to the IR. Only the port will be kept in the IR.
Currently IE is unable to handle such a graph, so this transformation is a workaround for such cases.
For instance, this case appears for Wide and Deep model.
@ -123,7 +125,8 @@ class RemoveConstToResult(BackReplacementPattern):
return dict(
nodes=[
('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
('result_node', {'type': 'Result', 'kind': 'op'}),
('result_node', {'type': 'Result', 'kind': 'op',
'keep_output_port': lambda attr: not attr}),
],
edges=[
('const_data', 'result_node')

View File

@ -107,7 +107,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data', 'value': np.array(5)}),
('result_node', {'type': 'Result', 'kind': 'op'}),
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': False}),
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_1_data', {'kind': 'data'}),
@ -150,6 +150,58 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
self.assertNotIn('const_data', graph.node)
self.assertNotIn('result_node', graph.node)
def test_only_consumer_keep_result(self):
"""Result node with keep_output_port=True is the only consumer of a Split output port and must be kept"""
nodes = [
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_1_data', {'kind': 'data'}),
('placeholder_2', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_2_data', {'kind': 'data'}),
('shape_of', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
('shape_of_data', {'kind': 'data'}),
('split', {'type': 'Split', 'kind': 'op', 'op': 'Split'}),
('split_data1', {'kind': 'data'}),
('split_data2', {'kind': 'data'}),
('result_node1', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul'}),
('mul_data', {'kind': 'data'}),
('result_node2', {'type': 'Result', 'kind': 'op'}),
]
edges = [
('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_1_data', 'shape_of'),
('shape_of', 'shape_of_data'),
('shape_of_data', 'split'),
('split', 'split_data1', {'in': 0}),
('split', 'split_data2', {'in': 1}),
('split_data1', 'result_node1'),
('split_data2', 'mul'),
('placeholder_2_data', 'mul'),
('mul', 'mul_data'),
('mul_data', 'result_node2'),
]
graph = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
tested_pattern = RemoveConstToResult()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mul_data')
self.assertTrue(flag, resp)
self.assertIn('split_data1', graph.node)
self.assertIn('split_data2', graph.node)
self.assertIn('result_node1', graph.node)
def test_two_consumers(self):
"""Const data node has two consumers: Result and ReLU"""
nodes = [
@ -190,3 +242,34 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
self.assertTrue(flag, resp)
self.assertNotIn('result_node', graph.node)
def test_two_consumers_keep_outputs(self):
"""Const data node has two consumers: Result (with keep_output_port=True, so it is kept) and ReLU"""
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data', 'value': np.array(5)}),
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),
]
edges = [
('const_node', 'const_data'),
('const_data', 'result_node'),
('const_data', 'relu_1'),
('relu_1', 'relu_1_data')
]
graph = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
tested_pattern = RemoveConstToResult()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
self.assertTrue(flag, resp)
self.assertIn('result_node', graph.node)

View File

@ -52,7 +52,7 @@ class TopKNormalizer(BackReplacementPattern):
"""
if node.out_port(0).disconnected():
output = Result(node.graph, {'name': node.name + '/Result_port_0/',
'remove_from_xml': node.has_and_set('remove_values_output')}).create_node()
'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
node.out_port(0).get_connection().set_destination(output.in_port(0))
if node.out_port(1).disconnected():
output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()

View File

@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
from mo.graph.graph import Graph
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.result import Result
@ -21,23 +21,19 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
def run_after(self):
return [TensorIteratorMerge]
@staticmethod
def pattern():
return dict(
nodes=[('op', dict(kind='op', op='Split'))],
edges=[],
)
def find_and_replace_pattern(self, graph: Graph):
for split_node in graph.get_op_nodes(op='Split'):
AddFakeOutputsToSplit.split_normalize_outputs(split_node)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
def split_normalize_outputs(node: Node):
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
for p in range(node.out_ports_count):
if p not in node.out_ports():
node.add_output_port(p)
if node.out_port(p).disconnected():
res_node = Result(graph, {'name': node.name + '/Fake_output_{}/'.format(p)}).create_node()
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
'keep_output_port': True}).create_node()
node.out_port(p).connect(res_node.in_port(0))
@ -76,4 +72,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
if not node.has_valid('out_ports_count'):
node['out_ports_count'] = len(size_splits)
AddFakeOutputsToSplit().replace_pattern(graph, match)
AddFakeOutputsToSplit().split_normalize_outputs(node)

View File

@ -247,7 +247,7 @@ def serialize_node_attributes(
unsupported):
# the Result op may be marked so it should not appear in the IR. For example, refer to transformation
# model-optimizer/extensions/back/TopKNormalizer.py
if isinstance(node, Node) and node.soft_get('result' == 'Result') and node.has_and_set('remove_from_xml'):
if isinstance(node, Node) and node.soft_get('type') == 'Result' and node.has_and_set('keep_output_port'):
return
try:
for s in schema:

View File

@ -7,6 +7,7 @@ import os
import numpy as np
from extensions.back.TopKNormalizer import TopKNormalizer
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
from extensions.ops.Cast import Cast
from extensions.ops.ReduceOps import ReduceOp
from extensions.ops.activation_ops import Activation
@ -272,6 +273,8 @@ postprocessing_op_nodes = {
'Assign': assign_add_output_result,
'TensorIterator': ti_add_edge_attrs,
'TopK': TopKNormalizer.normalize_outputs,
# Normalize Split outputs in the IR generated by the IR reader
'Split': AddFakeOutputsToSplit.split_normalize_outputs
}