Correct removing nodes from graph and add test for ConstToResult transform (#1083)

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2020-06-24 17:39:08 +05:00
committed by GitHub
parent c26ec8b312
commit 92c1333653
2 changed files with 92 additions and 2 deletions

View File

@@ -164,7 +164,7 @@ class RemoveConstToResult(BackReplacementPattern):
nodes_to_remove.append(const_node.id)
nodes_to_remove.append(const_data_node.id)
graph.remove_node(nodes_to_remove)
graph.remove_nodes_from(nodes_to_remove)
class NormalizeTI(BackReplacementPattern):

View File

@@ -17,7 +17,7 @@ import unittest
import numpy as np
from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement
from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement, RemoveConstToResult
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph_with_attrs
@@ -112,3 +112,93 @@ class CreateConstNodesReplacementTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
self.assertTrue(flag, resp)
class RemoveConstToResultReplacementTest(unittest.TestCase):
def test_only_consumer(self):
"""Result node is only consumer of Const data node"""
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('result_node', {'type': 'Result', 'kind': 'op'}),
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_1_data', {'kind': 'data'}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),
]
edges = [
('const_node', 'const_data'),
('const_data', 'result_node'),
('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'relu_1'),
('relu_1', 'relu_1_data')
]
new_nodes=[
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
('placeholder_1_data', {'kind': 'data'}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),
]
new_edges=[
('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'relu_1'),
('relu_1', 'relu_1_data')
]
graph = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=new_nodes,
edges_with_attrs=new_edges,
)
tested_pattern = RemoveConstToResult()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
self.assertTrue(flag, resp)
self.assertNotIn('const_node', graph.node)
self.assertNotIn('const_data', graph.node)
self.assertNotIn('result_node', graph.node)
def test_two_consumers(self):
"""Const data node has two consumers: Result and ReLu"""
nodes = [
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('result_node', {'type': 'Result', 'kind': 'op'}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),
]
edges = [
('const_node', 'const_data'),
('const_data', 'result_node'),
('const_data', 'relu_1'),
('relu_1', 'relu_1_data')
]
new_nodes=[
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
('relu_1_data', {'kind': 'data'}),
]
new_edges=[
('const_node', 'const_data'),
('const_data', 'relu_1'),
('relu_1', 'relu_1_data')
]
graph = build_graph_with_attrs(
nodes_with_attrs=nodes,
edges_with_attrs=edges,
)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=new_nodes,
edges_with_attrs=new_edges,
)
tested_pattern = RemoveConstToResult()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
self.assertTrue(flag, resp)
self.assertNotIn('result_node', graph.node)