diff --git a/model-optimizer/extensions/front/interpolate_reshape.py b/model-optimizer/extensions/front/interpolate_reshape.py index 4b3dfc4b5b1..c75bd8f8004 100644 --- a/model-optimizer/extensions/front/interpolate_reshape.py +++ b/model-optimizer/extensions/front/interpolate_reshape.py @@ -171,11 +171,12 @@ class InterpolateWithConcat(FrontReplacementPattern): # Interpolate could be connected to Concat through identity operations, skipping them next_node = self.get_single_output_destination_safely(interpolate) - while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'): - node = self.get_single_output_destination_safely(next_node) - if node is not None: - next_node = node - else: - break - if next_node.soft_get('type') == 'Concat': - self.make_interpolate_reshape_able(interpolate, next_node) + if next_node is not None: + while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'): + node = self.get_single_output_destination_safely(next_node) + if node is not None: + next_node = node + else: + break + if next_node.soft_get('type') == 'Concat': + self.make_interpolate_reshape_able(interpolate, next_node) diff --git a/model-optimizer/extensions/front/interpolate_reshape_test.py b/model-optimizer/extensions/front/interpolate_reshape_test.py index 1a23695d5de..2499b065738 100644 --- a/model-optimizer/extensions/front/interpolate_reshape_test.py +++ b/model-optimizer/extensions/front/interpolate_reshape_test.py @@ -44,7 +44,8 @@ nodes = { **regular_op_with_shaped_data('identity_11', [1, 3, 60, 160], {'identity': True, 'op': 'Identity'}), **regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1, 'op': 'Concat'}), - **result(), + **result('output'), + **result('output_1'), } @@ -110,6 +111,29 @@ class TestInterpolateConcat(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) self.assertTrue(flag, resp) + def test_interpolate_concat_negate(self): + graph = build_graph(nodes, [ + *connect('placeholder', '0:interpolate'), + *connect('out_shape', '1:interpolate'), + *connect('interpolate', 'identity_00'), + *connect('interpolate', 'identity_01'), + *connect('identity_00', 'output'), + *connect('identity_01', 'output_1'), + ], nodes_with_edges_only=True) + + InterpolateWithConcat().find_and_replace_pattern(graph) + graph.clean_up() + graph_ref = build_graph(nodes, [ + *connect('placeholder', '0:interpolate'), + *connect('out_shape', '1:interpolate'), + *connect('interpolate', 'identity_00'), + *connect('interpolate', 'identity_01'), + *connect('identity_00', 'output'), + *connect('identity_01', 'output_1'), + ], nodes_with_edges_only=True) + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp) + @generate(*[ {'concat': {'axis': None}}, {'concat': {'axis': 2}},