Fixed regression after mapping bug fix (#3770)
* Fixed regression after mapping bug fix * Unit test for atribute absence.
This commit is contained in:
parent
fe229ba19d
commit
c6e57ef99f
@ -68,10 +68,13 @@ class AddMeanScaleValues(MiddleReplacementPattern):
|
||||
# should keep the link to the input layer. Parameter node in framework
|
||||
# should map to parameter node in IR.
|
||||
# For this reason 'fw_tensor_debug_info' should be kept in data node.
|
||||
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
|
||||
has_debug_info = 'fw_tensor_debug_info' in input_node.out_node(0)
|
||||
if has_debug_info:
|
||||
fw_name = input_node.out_node(0)['fw_tensor_debug_info']
|
||||
dst.get_connection().set_source(preprocessing.out_port(0))
|
||||
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
|
||||
del preprocessing.out_node(0)['fw_tensor_debug_info']
|
||||
if has_debug_info:
|
||||
input_node.out_node(0)['fw_tensor_debug_info'] = fw_name
|
||||
del preprocessing.out_node(0)['fw_tensor_debug_info']
|
||||
|
||||
input_node.out_port(0).connect(preprocessing.in_port(0))
|
||||
|
||||
|
@ -284,3 +284,22 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_debug_info_absence(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', '0:mul_scale'),
|
||||
*connect('scale', '1:mul_scale'),
|
||||
*connect('mul_scale', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values=[[np.array([1., 2., 3.]), np.array([1., 2., 3.])]])
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, [])
|
||||
|
Loading…
Reference in New Issue
Block a user