Fixed mapping of input name (#3737)

* Fixed mapping of input name

* Fixed unit tests

* Fixed mapping of input name

* Fixed unit tests

* attributes check fix

* PEP8 code format

* code duplicate removal

* variable rename
This commit is contained in:
Anastasia Popova 2020-12-28 22:43:54 +03:00 committed by GitHub
parent 37b6e75730
commit 631d452258
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 6 deletions

View File

@ -64,7 +64,14 @@ class AddMeanScaleValues(MiddleReplacementPattern):
for dst in input_node.out_port(0).get_destinations():
if dst.node.soft_get('type') != 'ShapeOf':
# After the insertion of additional operations model optimizer
# should keep the link to the input layer. Parameter node in framework
# should map to parameter node in IR.
# For this reason 'fw_tensor_debug_info' should be kept in data node.
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
dst.get_connection().set_source(preprocessing.out_port(0))
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
del preprocessing.out_node(0)['fw_tensor_debug_info']
input_node.out_port(0).connect(preprocessing.in_port(0))

View File

@ -20,6 +20,7 @@ import numpy as np
from extensions.middle.AddMeanScaleValues import AddMeanScaleValues
from extensions.middle.ScaleInput import ScaleInput
from mo.graph.graph import Graph, Node
from mo.utils.cli_parser import get_mean_scale_dictionary, parse_tuple_pairs
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, result, connect, connect_data, \
@ -45,6 +46,25 @@ nodes = {
class AddMeanScaleValuesTest(unittest.TestCase):
def check_graph_attrs(self, graph: Graph, graph_ref: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
if node.soft_get('name') in parameter_node_names:
self.assertTrue(node.soft_get('type') == 'Parameter')
out_node = node.out_node(0)
out_node_ref = Node(graph_ref, node.id).out_node(0)
self.assertTrue(out_node['fw_tensor_debug_info'] == out_node_ref['fw_tensor_debug_info'])
else:
if 0 in node.out_nodes():
out_node = node.out_node(0)
self.assertFalse('fw_tensor_debug_info' in out_node)
def set_graph_attrs(self, graph: Graph, parameter_node_names: list):
for node in graph.get_op_nodes():
if node.soft_get('name') in parameter_node_names:
self.assertTrue(node.soft_get('type') == 'Parameter')
out_node = node.out_node(0)
out_node['fw_tensor_debug_info'] = ['fw_name', 0]
def test_mean_values_with_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
@ -58,18 +78,21 @@ class AddMeanScaleValuesTest(unittest.TestCase):
argv = Namespace(mean_scale_values=mean_scale)
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_without_data_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
*connect('mean', '1:add_mean'),
*connect('add_mean', 'result'),
])
], {'parameter': {'name': 'None'}})
mean_values = parse_tuple_pairs('(1,2,3)')
scale_values = parse_tuple_pairs('')
@ -78,11 +101,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'None'}},
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['None'])
self.set_graph_attrs(graph_ref, ['None'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['None'])
def test_mean_values_explicit_and_optimized(self):
graph_ref = build_graph(nodes, [
@ -96,6 +122,8 @@ class AddMeanScaleValuesTest(unittest.TestCase):
'parameter_2': {'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'result_2')],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
@ -103,6 +131,7 @@ class AddMeanScaleValuesTest(unittest.TestCase):
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
def test_mean_values_explicit_and_scale_values_optimized(self):
graph_ref = build_graph(nodes, [
@ -113,11 +142,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
argv = Namespace(mean_scale_values={'parameter': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_optimized_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
@ -129,11 +161,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
argv = Namespace(
mean_scale_values={'parameter': {'scale': np.array([1., 2., 3.]), 'mean': np.array([0., 0., 0.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_explicit_and_scale_values_explicit(self):
graph_ref = build_graph(nodes, [
@ -147,11 +182,14 @@ class AddMeanScaleValuesTest(unittest.TestCase):
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
graph = build_graph(nodes, [*connect('parameter', 'result')],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_explicit_and_scale_values_explicit_on_cutted_graph(self):
"""
@ -173,6 +211,8 @@ class AddMeanScaleValuesTest(unittest.TestCase):
graph = build_graph(
nodes, [*connect('parameter', 'result'), *connect('parameter_2', 'op'), *connect('op', 'result_2')],
{'parameter_2': {'initial_node_name': 'op'}}, nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter', 'parameter_2'])
self.set_graph_attrs(graph_ref, ['parameter', 'parameter_2'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
@ -180,6 +220,7 @@ class AddMeanScaleValuesTest(unittest.TestCase):
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter', 'parameter_2'])
def test_mean_values_explicit_and_scale_values_explicit_with_shape_of(self):
graph_ref = build_graph(nodes,
@ -203,6 +244,8 @@ class AddMeanScaleValuesTest(unittest.TestCase):
*connect('shape_of', 'result_2'),
],
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
@ -210,29 +253,34 @@ class AddMeanScaleValuesTest(unittest.TestCase):
self.assertTrue(flag, resp)
(flag, resp) = compare_graphs(graph, graph_ref, 'result_2', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
class ScaleInputTests(unittest.TestCase):
def test_scale_input(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:mul_scale'),
*connect('scale', '1:mul_scale'),
*connect('mul_scale', 'result'),
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)},
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1/255)}})
], {'scale': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)},
'scale_d': {'shape': [1, 1, 1, 1], 'value': np.array(1 / 255)}})
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=255))
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_scale_input_2(self):
graph_ref = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True)
graph = build_graph(nodes, connect('parameter', 'result'), nodes_with_edges_only=True, cli=Namespace(scale=1))
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
ScaleInput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])