Add keep split output ports without consumers (#5136)
* Add keep split output ports without consumers * Fix ir reader for split outputs * Update unit tests * Refactoring code according to review * Fix unit test * Fix
This commit is contained in:
parent
6d2740a335
commit
efeff1ee3e
@ -106,6 +106,8 @@ class RemoveConstToResult(BackReplacementPattern):
|
||||
Transformation looks for a constant sub-graph followed by Result operation.
|
||||
If sub-graph is Const->data->Result -- then all three nodes are removed.
|
||||
If there is more complex constant sub-graph -- then only Result node is removed.
|
||||
If Result node has keep_output_port attribute True the node will not to be removed from graph but
|
||||
the Result node will not to be saved to IR. Only port will be kept in IR.
|
||||
|
||||
Currently IE is unable to handle such graph so this transformation is a work around for such case.
|
||||
For instance, this case appears for Wide and Deep model.
|
||||
@ -123,7 +125,8 @@ class RemoveConstToResult(BackReplacementPattern):
|
||||
return dict(
|
||||
nodes=[
|
||||
('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op',
|
||||
'keep_output_port': lambda attr: not attr}),
|
||||
],
|
||||
edges=[
|
||||
('const_data', 'result_node')
|
||||
|
@ -107,7 +107,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
||||
nodes = [
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': False}),
|
||||
|
||||
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||
('placeholder_1_data', {'kind': 'data'}),
|
||||
@ -150,6 +150,58 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
||||
self.assertNotIn('const_data', graph.node)
|
||||
self.assertNotIn('result_node', graph.node)
|
||||
|
||||
|
||||
def test_only_consumer_keep_result(self):
|
||||
"""Result node is only consumer of Const data node"""
|
||||
nodes = [
|
||||
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||
('placeholder_1_data', {'kind': 'data'}),
|
||||
('placeholder_2', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||
('placeholder_2_data', {'kind': 'data'}),
|
||||
('shape_of', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
|
||||
('shape_of_data', {'kind': 'data'}),
|
||||
('split', {'type': 'Split', 'kind': 'op', 'op': 'Split'}),
|
||||
('split_data1', {'kind': 'data'}),
|
||||
('split_data2', {'kind': 'data'}),
|
||||
('result_node1', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
|
||||
|
||||
('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul'}),
|
||||
('mul_data', {'kind': 'data'}),
|
||||
('result_node2', {'type': 'Result', 'kind': 'op'}),
|
||||
]
|
||||
edges = [
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_2', 'placeholder_2_data'),
|
||||
('placeholder_1_data', 'shape_of'),
|
||||
('shape_of', 'shape_of_data'),
|
||||
('shape_of_data', 'split'),
|
||||
('split', 'split_data1', {'in': 0}),
|
||||
('split', 'split_data2', {'in': 1}),
|
||||
|
||||
('split_data1', 'result_node1'),
|
||||
('split_data2', 'mul'),
|
||||
('placeholder_2_data', 'mul'),
|
||||
('mul', 'mul_data'),
|
||||
('mul_data', 'result_node2'),
|
||||
]
|
||||
|
||||
graph = build_graph_with_attrs(
|
||||
nodes_with_attrs=nodes,
|
||||
edges_with_attrs=edges,
|
||||
)
|
||||
graph_ref = build_graph_with_attrs(
|
||||
nodes_with_attrs=nodes,
|
||||
edges_with_attrs=edges,
|
||||
)
|
||||
tested_pattern = RemoveConstToResult()
|
||||
tested_pattern.find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mul_data')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertIn('split_data1', graph.node)
|
||||
self.assertIn('split_data2', graph.node)
|
||||
self.assertIn('result_node1', graph.node)
|
||||
|
||||
|
||||
def test_two_consumers(self):
|
||||
"""Const data node has two consumers: Result and ReLu"""
|
||||
nodes = [
|
||||
@ -190,3 +242,34 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertNotIn('result_node', graph.node)
|
||||
|
||||
|
||||
def test_two_consumers_keep_outputs(self):
|
||||
"""Const data node has two consumers: Result and ReLu"""
|
||||
nodes = [
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
|
||||
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||
('relu_1_data', {'kind': 'data'}),
|
||||
]
|
||||
edges = [
|
||||
('const_node', 'const_data'),
|
||||
('const_data', 'result_node'),
|
||||
('const_data', 'relu_1'),
|
||||
('relu_1', 'relu_1_data')
|
||||
]
|
||||
|
||||
graph = build_graph_with_attrs(
|
||||
nodes_with_attrs=nodes,
|
||||
edges_with_attrs=edges,
|
||||
)
|
||||
graph_ref = build_graph_with_attrs(
|
||||
nodes_with_attrs=nodes,
|
||||
edges_with_attrs=edges,
|
||||
)
|
||||
tested_pattern = RemoveConstToResult()
|
||||
tested_pattern.find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertIn('result_node', graph.node)
|
||||
|
@ -52,7 +52,7 @@ class TopKNormalizer(BackReplacementPattern):
|
||||
"""
|
||||
if node.out_port(0).disconnected():
|
||||
output = Result(node.graph, {'name': node.name + '/Result_port_0/',
|
||||
'remove_from_xml': node.has_and_set('remove_values_output')}).create_node()
|
||||
'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
|
||||
node.out_port(0).get_connection().set_destination(output.in_port(0))
|
||||
if node.out_port(1).disconnected():
|
||||
output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.result import Result
|
||||
|
||||
@ -21,23 +21,19 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
|
||||
def run_after(self):
|
||||
return [TensorIteratorMerge]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[('op', dict(kind='op', op='Split'))],
|
||||
edges=[],
|
||||
)
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for split_node in graph.get_op_nodes(op='Split'):
|
||||
AddFakeOutputsToSplit.split_normalize_outputs(split_node)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
|
||||
def split_normalize_outputs(node: Node):
|
||||
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
|
||||
for p in range(node.out_ports_count):
|
||||
if p not in node.out_ports():
|
||||
node.add_output_port(p)
|
||||
if node.out_port(p).disconnected():
|
||||
res_node = Result(graph, {'name': node.name + '/Fake_output_{}/'.format(p)}).create_node()
|
||||
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
|
||||
'keep_output_port': True}).create_node()
|
||||
node.out_port(p).connect(res_node.in_port(0))
|
||||
|
||||
|
||||
@ -76,4 +72,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
|
||||
if not node.has_valid('out_ports_count'):
|
||||
node['out_ports_count'] = len(size_splits)
|
||||
|
||||
AddFakeOutputsToSplit().replace_pattern(graph, match)
|
||||
AddFakeOutputsToSplit().split_normalize_outputs(node)
|
||||
|
@ -247,7 +247,7 @@ def serialize_node_attributes(
|
||||
unsupported):
|
||||
# the Result op may be marked so it should not appear in the IR. For example, refer to transformation
|
||||
# model-optimizer/extensions/back/TopKNormalizer.py
|
||||
if isinstance(node, Node) and node.soft_get('result' == 'Result') and node.has_and_set('remove_from_xml'):
|
||||
if isinstance(node, Node) and node.soft_get('type') == 'Result' and node.has_and_set('keep_output_port'):
|
||||
return
|
||||
try:
|
||||
for s in schema:
|
||||
|
@ -7,6 +7,7 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.TopKNormalizer import TopKNormalizer
|
||||
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.ReduceOps import ReduceOp
|
||||
from extensions.ops.activation_ops import Activation
|
||||
@ -272,6 +273,8 @@ postprocessing_op_nodes = {
|
||||
'Assign': assign_add_output_result,
|
||||
'TensorIterator': ti_add_edge_attrs,
|
||||
'TopK': TopKNormalizer.normalize_outputs,
|
||||
# Call normalize Split outputs for generated IR by ir-reader
|
||||
'Split': AddFakeOutputsToSplit.split_normalize_outputs
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user