Fixed regression after mapping bug fix (#3770)
* Fixed regression after mapping bug fix * Unit test for atribute absence.
This commit is contained in:
parent
fe229ba19d
commit
c6e57ef99f
@ -68,8 +68,11 @@ class AddMeanScaleValues(MiddleReplacementPattern):
|
|||||||
# should keep the link to the input layer. Parameter node in framework
|
# should keep the link to the input layer. Parameter node in framework
|
||||||
# should map to parameter node in IR.
|
# should map to parameter node in IR.
|
||||||
# For this reason 'fw_tensor_debug_info' should be kept in data node.
|
# For this reason 'fw_tensor_debug_info' should be kept in data node.
|
||||||
|
has_debug_info = 'fw_tensor_debug_info' in input_node.out_node(0)
|
||||||
|
if has_debug_info:
|
||||||
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
|
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
|
||||||
dst.get_connection().set_source(preprocessing.out_port(0))
|
dst.get_connection().set_source(preprocessing.out_port(0))
|
||||||
|
if has_debug_info:
|
||||||
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
|
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
|
||||||
del preprocessing.out_node(0)['fw_tensor_debug_info']
|
del preprocessing.out_node(0)['fw_tensor_debug_info']
|
||||||
|
|
||||||
|
@ -284,3 +284,22 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
|||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||||
|
|
||||||
|
def test_debug_info_absence(self):
|
||||||
|
graph_ref = build_graph(nodes, [
|
||||||
|
*connect('parameter', '0:add_mean'),
|
||||||
|
*connect('mean', '1:add_mean'),
|
||||||
|
*connect('add_mean', '0:mul_scale'),
|
||||||
|
*connect('scale', '1:mul_scale'),
|
||||||
|
*connect('mul_scale', 'result'),
|
||||||
|
])
|
||||||
|
|
||||||
|
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
|
||||||
|
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||||
|
nodes_with_edges_only=True, cli=argv)
|
||||||
|
graph.graph['layout'] = 'NCHW'
|
||||||
|
|
||||||
|
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
self.check_graph_attrs(graph, graph_ref, [])
|
||||||
|
Loading…
Reference in New Issue
Block a user