[ MO ] Extended Const->Result replacer (#1688)
* [ MO ] Extended Const->Result replacer
This commit is contained in:
parent
cb8892ca2b
commit
3cc7896e42
@ -116,36 +116,38 @@ class CreateConstNodesReplacement(BackReplacementPattern):
|
||||
|
||||
class RemoveConstToResult(BackReplacementPattern):
|
||||
"""
|
||||
Transformation looks for a sub-graph "Const->Result" and removes Result node.
|
||||
Currently IE is unable to handle such graph so this transformation removes to work around this case.
|
||||
Transformation looks for a constant sub-graph followed by Result operation.
|
||||
If sub-graph is Const->data->Result -- then all three nodes are removed.
|
||||
If there is more complex constant sub-graph -- then only Result node is removed.
|
||||
|
||||
Currently IE is unable to handle such graph so this transformation is a work around for such case.
|
||||
For instance, this case appears for Wide and Deep model.
|
||||
"""
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data'}),
|
||||
('const_data', {'kind': 'data', 'value': lambda value: value is not None}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
],
|
||||
edges=[
|
||||
('const_node', 'const_data'),
|
||||
('const_data', 'result_node')
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
const_node = match['const_node']
|
||||
const_data_node = match['const_data']
|
||||
result_node = match['result_node']
|
||||
nodes_to_remove = [result_node.id]
|
||||
|
||||
# in case only const data consumer that is the result node, remove the whole sub-graph
|
||||
if len(const_node.out_port(0).get_destinations()) == 1:
|
||||
nodes_to_remove.append(const_node.id)
|
||||
parent_node = result_node.in_port(0).get_source().node
|
||||
if parent_node.soft_get('type') == 'Const' and len(parent_node.out_port(0).get_destinations()) == 1:
|
||||
nodes_to_remove.append(parent_node.id)
|
||||
nodes_to_remove.append(const_data_node.id)
|
||||
|
||||
graph.remove_nodes_from(nodes_to_remove)
|
||||
@ -174,7 +176,6 @@ class NormalizeTI(BackReplacementPattern):
|
||||
ti.output_port_map = [dict(unique_r) for unique_r in set([tuple(rec.items()) for rec in ti.output_port_map])]
|
||||
ti.back_edges = [dict(unique_rec) for unique_rec in set([tuple(rec.items()) for rec in ti.back_edges])]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def external_nodes_normalization(ti):
|
||||
"""
|
||||
|
@ -118,7 +118,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
||||
"""Result node is only consumer of Const data node"""
|
||||
nodes = [
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data'}),
|
||||
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
|
||||
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||
@ -166,7 +166,7 @@ class RemoveConstToResultReplacementTest(unittest.TestCase):
|
||||
"""Const data node has two consumers: Result and ReLu"""
|
||||
nodes = [
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data'}),
|
||||
('const_data', {'kind': 'data', 'value': np.array(5)}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||
('relu_1_data', {'kind': 'data'}),
|
||||
|
Loading…
Reference in New Issue
Block a user