Fixed regression after mapping bug fix (#3770)

* Fixed regression after mapping bug fix

* Unit test for atribute absence.
This commit is contained in:
Anastasia Popova 2020-12-30 19:41:10 +03:00 committed by GitHub
parent fe229ba19d
commit c6e57ef99f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 3 deletions

View File

@ -68,10 +68,13 @@ class AddMeanScaleValues(MiddleReplacementPattern):
# should keep the link to the input layer. Parameter node in framework
# should map to parameter node in IR.
# For this reason 'fw_tensor_debug_info' should be kept in data node.
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
has_debug_info = 'fw_tensor_debug_info' in input_node.out_node(0)
if has_debug_info:
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
dst.get_connection().set_source(preprocessing.out_port(0))
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
del preprocessing.out_node(0)['fw_tensor_debug_info']
if has_debug_info:
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
del preprocessing.out_node(0)['fw_tensor_debug_info']
input_node.out_port(0).connect(preprocessing.in_port(0))

View File

@ -284,3 +284,22 @@ class AddMeanScaleValuesTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_debug_info_absence(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
*connect('mean', '1:add_mean'),
*connect('add_mean', '0:mul_scale'),
*connect('scale', '1:mul_scale'),
*connect('mul_scale', 'result'),
])
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
graph = build_graph(nodes, [*connect('parameter', 'result')],
nodes_with_edges_only=True, cli=argv)
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, [])