Correct removing nodes from graph and add test for ConstToResult transform (#1083)
Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -164,7 +164,7 @@ class RemoveConstToResult(BackReplacementPattern):
|
||||
nodes_to_remove.append(const_node.id)
|
||||
nodes_to_remove.append(const_data_node.id)
|
||||
|
||||
graph.remove_node(nodes_to_remove)
|
||||
graph.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
|
||||
class NormalizeTI(BackReplacementPattern):
|
||||
|
||||
@@ -17,7 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement
|
||||
from extensions.back.SpecialNodesFinalization import CreateConstNodesReplacement, RemoveConstToResult
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph_with_attrs
|
||||
|
||||
@@ -112,3 +112,93 @@ class CreateConstNodesReplacementTest(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
class RemoveConstToResultReplacementTest(unittest.TestCase):
|
||||
def test_only_consumer(self):
|
||||
"""Result node is only consumer of Const data node"""
|
||||
nodes = [
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data'}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
|
||||
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||
('placeholder_1_data', {'kind': 'data'}),
|
||||
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||
('relu_1_data', {'kind': 'data'}),
|
||||
]
|
||||
edges = [
|
||||
('const_node', 'const_data'),
|
||||
('const_data', 'result_node'),
|
||||
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'relu_1'),
|
||||
('relu_1', 'relu_1_data')
|
||||
]
|
||||
new_nodes=[
|
||||
('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
|
||||
('placeholder_1_data', {'kind': 'data'}),
|
||||
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||
('relu_1_data', {'kind': 'data'}),
|
||||
]
|
||||
new_edges=[
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'relu_1'),
|
||||
('relu_1', 'relu_1_data')
|
||||
]
|
||||
|
||||
graph = build_graph_with_attrs(
|
||||
nodes_with_attrs=nodes,
|
||||
edges_with_attrs=edges,
|
||||
)
|
||||
graph_ref = build_graph_with_attrs(
|
||||
nodes_with_attrs=new_nodes,
|
||||
edges_with_attrs=new_edges,
|
||||
)
|
||||
tested_pattern = RemoveConstToResult()
|
||||
tested_pattern.find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertNotIn('const_node', graph.node)
|
||||
self.assertNotIn('const_data', graph.node)
|
||||
self.assertNotIn('result_node', graph.node)
|
||||
|
||||
def test_two_consumers(self):
|
||||
"""Const data node has two consumers: Result and ReLu"""
|
||||
nodes = [
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data'}),
|
||||
('result_node', {'type': 'Result', 'kind': 'op'}),
|
||||
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||
('relu_1_data', {'kind': 'data'}),
|
||||
]
|
||||
edges = [
|
||||
('const_node', 'const_data'),
|
||||
('const_data', 'result_node'),
|
||||
('const_data', 'relu_1'),
|
||||
('relu_1', 'relu_1_data')
|
||||
]
|
||||
new_nodes=[
|
||||
('const_node', {'type': 'Const', 'kind': 'op'}),
|
||||
('const_data', {'kind': 'data'}),
|
||||
('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
|
||||
('relu_1_data', {'kind': 'data'}),
|
||||
]
|
||||
new_edges=[
|
||||
('const_node', 'const_data'),
|
||||
('const_data', 'relu_1'),
|
||||
('relu_1', 'relu_1_data')
|
||||
]
|
||||
|
||||
graph = build_graph_with_attrs(
|
||||
nodes_with_attrs=nodes,
|
||||
edges_with_attrs=edges,
|
||||
)
|
||||
graph_ref = build_graph_with_attrs(
|
||||
nodes_with_attrs=new_nodes,
|
||||
edges_with_attrs=new_edges,
|
||||
)
|
||||
tested_pattern = RemoveConstToResult()
|
||||
tested_pattern.find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1_data')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertNotIn('result_node', graph.node)
|
||||
|
||||
Reference in New Issue
Block a user