Add keep split output ports without consumers (#5136)
* Add keep split output ports without consumers * Fix ir reader for split outputs * Update unit tests * Refactoring code according to review * Fix unit test * Fix
This commit is contained in:
parent
6d2740a335
commit
efeff1ee3e
@ -106,6 +106,8 @@ class RemoveConstToResult(BackReplacementPattern):
|
|||||||
Transformation looks for a constant sub-graph followed by Result operation.
|
Transformation looks for a constant sub-graph followed by Result operation.
|
||||||
If sub-graph is Const->data->Result -- then all three nodes are removed.
|
If sub-graph is Const->data->Result -- then all three nodes are removed.
|
||||||
If there is more complex constant sub-graph -- then only Result node is removed.
|
If there is more complex constant sub-graph -- then only Result node is removed.
|
||||||
|
If Result node has keep_output_port attribute True the node will not to be removed from graph but
|
||||||
|
the Result node will not to be saved to IR. Only port will be kept in IR.
|
||||||
|
|
||||||
Currently IE is unable to handle such graph so this transformation is a work around for such case.
|
Currently IE is unable to handle such graph so this transformation is a work around for such case.
|
||||||
For instance, this case appears for Wide and Deep model.
|
For instance, this case appears for Wide and Deep model.
|
||||||
@ -123,7 +125,8 @@ class RemoveConstToResult(BackReplacementPattern):
|
|||||||
return dict(
|
return dict(
|
||||||
nodes=[
|
nodes=[
|
||||||
('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
|
('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
|
||||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
('result_node', {'type': 'Result', 'kind': 'op',
|
||||||
|
'keep_output_port': lambda attr: not attr}),
|
||||||
],
|
],
|
||||||
edges=[
|
edges=[
|
||||||
('const_data', 'result_node')
|
('const_data', 'result_node')
|
||||||
|
@ -107,7 +107,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
|||||||
nodes = [
|
nodes = [
|
||||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||||
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
||||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': False}),
|
||||||
|
|
||||||
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||||
('placeholder_1_data', {'kind': 'data'}),
|
('placeholder_1_data', {'kind': 'data'}),
|
||||||
@ -150,6 +150,58 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
|||||||
self.assertNotIn('const_data', graph.node)
|
self.assertNotIn('const_data', graph.node)
|
||||||
self.assertNotIn('result_node', graph.node)
|
self.assertNotIn('result_node', graph.node)
|
||||||
|
|
||||||
|
|
||||||
|
def test_only_consumer_keep_result(self):
|
||||||
|
"""Result node is only consumer of Const data node"""
|
||||||
|
nodes = [
|
||||||
|
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||||
|
('placeholder_1_data', {'kind': 'data'}),
|
||||||
|
('placeholder_2', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||||
|
('placeholder_2_data', {'kind': 'data'}),
|
||||||
|
('shape_of', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
|
||||||
|
('shape_of_data', {'kind': 'data'}),
|
||||||
|
('split', {'type': 'Split', 'kind': 'op', 'op': 'Split'}),
|
||||||
|
('split_data1', {'kind': 'data'}),
|
||||||
|
('split_data2', {'kind': 'data'}),
|
||||||
|
('result_node1', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
|
||||||
|
|
||||||
|
('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul'}),
|
||||||
|
('mul_data', {'kind': 'data'}),
|
||||||
|
('result_node2', {'type': 'Result', 'kind': 'op'}),
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
('placeholder_1', 'placeholder_1_data'),
|
||||||
|
('placeholder_2', 'placeholder_2_data'),
|
||||||
|
('placeholder_1_data', 'shape_of'),
|
||||||
|
('shape_of', 'shape_of_data'),
|
||||||
|
('shape_of_data', 'split'),
|
||||||
|
('split', 'split_data1', {'in': 0}),
|
||||||
|
('split', 'split_data2', {'in': 1}),
|
||||||
|
|
||||||
|
('split_data1', 'result_node1'),
|
||||||
|
('split_data2', 'mul'),
|
||||||
|
('placeholder_2_data', 'mul'),
|
||||||
|
('mul', 'mul_data'),
|
||||||
|
('mul_data', 'result_node2'),
|
||||||
|
]
|
||||||
|
|
||||||
|
graph = build_graph_with_attrs(
|
||||||
|
nodes_with_attrs=nodes,
|
||||||
|
edges_with_attrs=edges,
|
||||||
|
)
|
||||||
|
graph_ref = build_graph_with_attrs(
|
||||||
|
nodes_with_attrs=nodes,
|
||||||
|
edges_with_attrs=edges,
|
||||||
|
)
|
||||||
|
tested_pattern = RemoveConstToResult()
|
||||||
|
tested_pattern.find_and_replace_pattern(graph)
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mul_data')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
self.assertIn('split_data1', graph.node)
|
||||||
|
self.assertIn('split_data2', graph.node)
|
||||||
|
self.assertIn('result_node1', graph.node)
|
||||||
|
|
||||||
|
|
||||||
def test_two_consumers(self):
|
def test_two_consumers(self):
|
||||||
"""Const data node has two consumers: Result and ReLu"""
|
"""Const data node has two consumers: Result and ReLu"""
|
||||||
nodes = [
|
nodes = [
|
||||||
@ -190,3 +242,34 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
|||||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
self.assertNotIn('result_node', graph.node)
|
self.assertNotIn('result_node', graph.node)
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_consumers_keep_outputs(self):
|
||||||
|
"""Const data node has two consumers: Result and ReLu"""
|
||||||
|
nodes = [
|
||||||
|
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||||
|
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
||||||
|
('result_node', {'type': 'Result', 'kind': 'op', 'keep_output_port': True}),
|
||||||
|
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||||
|
('relu_1_data', {'kind': 'data'}),
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
('const_node', 'const_data'),
|
||||||
|
('const_data', 'result_node'),
|
||||||
|
('const_data', 'relu_1'),
|
||||||
|
('relu_1', 'relu_1_data')
|
||||||
|
]
|
||||||
|
|
||||||
|
graph = build_graph_with_attrs(
|
||||||
|
nodes_with_attrs=nodes,
|
||||||
|
edges_with_attrs=edges,
|
||||||
|
)
|
||||||
|
graph_ref = build_graph_with_attrs(
|
||||||
|
nodes_with_attrs=nodes,
|
||||||
|
edges_with_attrs=edges,
|
||||||
|
)
|
||||||
|
tested_pattern = RemoveConstToResult()
|
||||||
|
tested_pattern.find_and_replace_pattern(graph)
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
self.assertIn('result_node', graph.node)
|
||||||
|
@ -52,7 +52,7 @@ class TopKNormalizer(BackReplacementPattern):
|
|||||||
"""
|
"""
|
||||||
if node.out_port(0).disconnected():
|
if node.out_port(0).disconnected():
|
||||||
output = Result(node.graph, {'name': node.name + '/Result_port_0/',
|
output = Result(node.graph, {'name': node.name + '/Result_port_0/',
|
||||||
'remove_from_xml': node.has_and_set('remove_values_output')}).create_node()
|
'keep_output_port': node.has_and_set('remove_values_output')}).create_node()
|
||||||
node.out_port(0).get_connection().set_destination(output.in_port(0))
|
node.out_port(0).get_connection().set_destination(output.in_port(0))
|
||||||
if node.out_port(1).disconnected():
|
if node.out_port(1).disconnected():
|
||||||
output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()
|
output = Result(node.graph, {'name': node.name + '/Result_port_1/'}).create_node()
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
|
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph, Node
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.result import Result
|
from mo.ops.result import Result
|
||||||
|
|
||||||
@ -21,23 +21,19 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
|
|||||||
def run_after(self):
|
def run_after(self):
|
||||||
return [TensorIteratorMerge]
|
return [TensorIteratorMerge]
|
||||||
|
|
||||||
@staticmethod
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
def pattern():
|
for split_node in graph.get_op_nodes(op='Split'):
|
||||||
return dict(
|
AddFakeOutputsToSplit.split_normalize_outputs(split_node)
|
||||||
nodes=[('op', dict(kind='op', op='Split'))],
|
|
||||||
edges=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def replace_pattern(graph: Graph, match: dict):
|
def split_normalize_outputs(node: Node):
|
||||||
node = match['op']
|
|
||||||
|
|
||||||
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
|
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
|
||||||
for p in range(node.out_ports_count):
|
for p in range(node.out_ports_count):
|
||||||
if p not in node.out_ports():
|
if p not in node.out_ports():
|
||||||
node.add_output_port(p)
|
node.add_output_port(p)
|
||||||
if node.out_port(p).disconnected():
|
if node.out_port(p).disconnected():
|
||||||
res_node = Result(graph, {'name': node.name + '/Fake_output_{}/'.format(p)}).create_node()
|
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
|
||||||
|
'keep_output_port': True}).create_node()
|
||||||
node.out_port(p).connect(res_node.in_port(0))
|
node.out_port(p).connect(res_node.in_port(0))
|
||||||
|
|
||||||
|
|
||||||
@ -76,4 +72,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
|
|||||||
if not node.has_valid('out_ports_count'):
|
if not node.has_valid('out_ports_count'):
|
||||||
node['out_ports_count'] = len(size_splits)
|
node['out_ports_count'] = len(size_splits)
|
||||||
|
|
||||||
AddFakeOutputsToSplit().replace_pattern(graph, match)
|
AddFakeOutputsToSplit().split_normalize_outputs(node)
|
||||||
|
@ -247,7 +247,7 @@ def serialize_node_attributes(
|
|||||||
unsupported):
|
unsupported):
|
||||||
# the Result op may be marked so it should not appear in the IR. For example, refer to transformation
|
# the Result op may be marked so it should not appear in the IR. For example, refer to transformation
|
||||||
# model-optimizer/extensions/back/TopKNormalizer.py
|
# model-optimizer/extensions/back/TopKNormalizer.py
|
||||||
if isinstance(node, Node) and node.soft_get('result' == 'Result') and node.has_and_set('remove_from_xml'):
|
if isinstance(node, Node) and node.soft_get('type') == 'Result' and node.has_and_set('keep_output_port'):
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
for s in schema:
|
for s in schema:
|
||||||
|
@ -7,6 +7,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.back.TopKNormalizer import TopKNormalizer
|
from extensions.back.TopKNormalizer import TopKNormalizer
|
||||||
|
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
|
||||||
from extensions.ops.Cast import Cast
|
from extensions.ops.Cast import Cast
|
||||||
from extensions.ops.ReduceOps import ReduceOp
|
from extensions.ops.ReduceOps import ReduceOp
|
||||||
from extensions.ops.activation_ops import Activation
|
from extensions.ops.activation_ops import Activation
|
||||||
@ -272,6 +273,8 @@ postprocessing_op_nodes = {
|
|||||||
'Assign': assign_add_output_result,
|
'Assign': assign_add_output_result,
|
||||||
'TensorIterator': ti_add_edge_attrs,
|
'TensorIterator': ti_add_edge_attrs,
|
||||||
'TopK': TopKNormalizer.normalize_outputs,
|
'TopK': TopKNormalizer.normalize_outputs,
|
||||||
|
# Call normalize Split outputs for generated IR by ir-reader
|
||||||
|
'Split': AddFakeOutputsToSplit.split_normalize_outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user