Fixed mapping of input name (#3737)
* Fixed mapping of input name * Fixed unit tests * Fixed mapping of input name * Fixed unit tests * attributes check fix * PEP8 code format * code duplicate removal * variable rename
This commit is contained in:
parent
37b6e75730
commit
631d452258
@ -64,7 +64,14 @@ class AddMeanScaleValues(MiddleReplacementPattern):
|
||||
|
||||
for dst in input_node.out_port(0).get_destinations():
|
||||
if dst.node.soft_get('type') != 'ShapeOf':
|
||||
# After the insertion of additional operations model optimizer
|
||||
# should keep the link to the input layer. Parameter node in framework
|
||||
# should map to parameter node in IR.
|
||||
# For this reason 'fw_tensor_debug_info' should be kept in data node.
|
||||
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
|
||||
dst.get_connection().set_source(preprocessing.out_port(0))
|
||||
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
|
||||
del preprocessing.out_node(0)['fw_tensor_debug_info']
|
||||
|
||||
input_node.out_port(0).connect(preprocessing.in_port(0))
|
||||
|
||||
|
@ -20,6 +20,7 @@ import numpy as np
|
||||
|
||||
from extensions.middle.AddMeanScaleValues import AddMeanScaleValues
|
||||
from extensions.middle.ScaleInput import ScaleInput
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, result, connect, connect_data, \
|
||||
@ -45,6 +46,25 @@ nodes = {
|
||||
|
||||
|
||||
class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
def check_graph_attrs(self, graph: Graph, graph_ref: Graph, parameter_node_names: list):
|
||||
for node in graph.get_op_nodes():
|
||||
if node.soft_get('name') in parameter_node_names:
|
||||
self.assertTrue(node.soft_get('type') == 'Parameter')
|
||||
out_node = node.out_node(0)
|
||||
out_node_ref = Node(graph_ref, node.id).out_node(0)
|
||||
self.assertTrue(out_node['fw_tensor_debug_info'] == out_node_ref['fw_tensor_debug_info'])
|
||||
else:
|
||||
if 0 in node.out_nodes():
|
||||
out_node = node.out_node(0)
|
||||
self.assertFalse('fw_tensor_debug_info' in out_node)
|
||||
|
||||
def set_graph_attrs(self, graph: Graph, parameter_node_names: list):
|
||||
for node in graph.get_op_nodes():
|
||||
if node.soft_get('name') in parameter_node_names:
|
||||
self.assertTrue(node.soft_get('type') == 'Parameter')
|
||||
out_node = node.out_node(0)
|
||||
out_node['fw_tensor_debug_info'] = ['fw_name', 0]
|
||||
|
||||
def test_mean_values_with_data_name(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
@ -58,18 +78,21 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
argv = Namespace(mean_scale_values=mean_scale)
|
||||
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_without_data_name(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
], {'parameter': {'name': 'None'}})
|
||||
|
||||
mean_values = parse_tuple_pairs('(1,2,3)')
|
||||
scale_values = parse_tuple_pairs('')
|
||||
@ -78,11 +101,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}},
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['None'])
|
||||
self.set_graph_attrs(graph_ref, ['None'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['None'])
|
||||
|
||||
def test_mean_values_explicit_and_optimized(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
@ -96,6 +122,8 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
'parameter_2': {'mean': np.array([0., 0., 0.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
@ -103,6 +131,7 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_optimized(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
@ -113,11 +142,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
|
||||
argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_optimized_and_scale_values_explicit(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
@ -129,11 +161,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
argv = Namespace(
|
||||
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
@ -147,11 +182,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
|
||||
"""
|
||||
@ -173,6 +211,8 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
graph = build_graph(
|
||||
nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
|
||||
{'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
|
||||
@ -180,6 +220,7 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
|
||||
|
||||
def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
|
||||
graph_ref = build_graph(nodes,
|
||||
@ -203,6 +244,8 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
*connect('shape_of', 'result_2'),
|
||||
],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
@ -210,29 +253,34 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
self.assertTrue(flag, resp)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
|
||||
class ScaleInputTests(unittest.TestCase):
|
||||
def test_scale_input(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)},
|
||||
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}})
|
||||
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)},
|
||||
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}})
|
||||
|
||||
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255))
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
ScaleInput().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_scale_input_2(self):
|
||||
graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True)
|
||||
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1))
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
ScaleInput().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
Loading…
Reference in New Issue
Block a user