[ MO ] Extended Const->Result replacer (#1688)

* [ MO ] Extended Const->Result replacer
This commit is contained in:
Evgenya Stepyreva 2020-08-10 15:36:05 +03:00 committed by GitHub
parent cb8892ca2b
commit 3cc7896e42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 11 deletions

View File

@ -116,36 +116,38 @@ class CreateConstNodesReplacement(BackReplacementPattern):
class RemoveConstToResult(BackReplacementPattern):
"""
Transformation looks for a sub-graph "Const->Result" and removes Result node.
Currently IE is unable to handle such graph so this transformation removes to work around this case.
Transformation looks for a constant sub-graph followed by Result operation.
If sub-graph is Const->data->Result -- then all three nodes are removed.
If there is more complex constant sub-graph -- then only Result node is removed.
Currently IE is unable to handle such graph so this transformation is a work around for such case.
For instance, this case appears for Wide and Deep model.
"""
enabled = True
force_clean_up = True
@staticmethod
def pattern():
return dict(
nodes=[
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
('result_node', {'type': 'Result', 'kind': 'op'}),
],
edges=[
('const_node', 'const_data'),
('const_data', 'result_node')
]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
const_node = match['const_node']
const_data_node = match['const_data']
result_node = match['result_node']
nodes_to_remove = [result_node.id]
# in case only const data consumer that is the result node, remove the whole sub-graph
if len(const_node.out_port(0).get_destinations()) == 1:
nodes_to_remove.append(const_node.id)
parent_node = result_node.in_port(0).get_source().node
if parent_node.soft_get('type') == 'Const' and len(parent_node.out_port(0).get_destinations()) == 1:
nodes_to_remove.append(parent_node.id)
nodes_to_remove.append(const_data_node.id)
graph.remove_nodes_from(nodes_to_remove)
@ -174,7 +176,6 @@ class NormalizeTI(BackReplacementPattern):
ti.output_port_map = [dict(unique_r) for unique_r in set([tuple(rec.items()) for rec in ti.output_port_map])]
ti.back_edges = [dict(unique_rec) for unique_rec in set([tuple(rec.items()) for rec in ti.back_edges])]
@staticmethod
def external_nodes_normalization(ti):
"""

View File

@ -118,7 +118,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
"""Result node is only consumer of Const data node"""
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('const_data', {'kind': 'data', 'value': np.array(5)}),
('result_node', {'type': 'Result', 'kind': 'op'}),
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
@ -166,7 +166,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
"""Const data node has two consumers: Result and ReLu"""
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('const_data', {'kind': 'data', 'value': np.array(5)}),
('result_node', {'type': 'Result', 'kind': 'op'}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),