InterpolateWithConcat pass fix (#1501)

* Fix InterpolateWithConcat pass

* Add test for None case
This commit is contained in:
Anton Chetverikov 2020-07-28 14:50:32 +03:00 committed by GitHub
parent cbdfa38392
commit 4e1f7d2b96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 9 deletions

View File

@ -171,11 +171,12 @@ class InterpolateWithConcat(FrontReplacementPattern):
# Interpolate could be connected to Concat through identity operations, skipping them
next_node = self.get_single_output_destination_safely(interpolate)
while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'):
node = self.get_single_output_destination_safely(next_node)
if node is not None:
next_node = node
else:
break
if next_node.soft_get('type') == 'Concat':
self.make_interpolate_reshape_able(interpolate, next_node)
if next_node is not None:
while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'):
node = self.get_single_output_destination_safely(next_node)
if node is not None:
next_node = node
else:
break
if next_node.soft_get('type') == 'Concat':
self.make_interpolate_reshape_able(interpolate, next_node)

View File

@ -44,7 +44,8 @@ nodes = {
**regular_op_with_shaped_data('identity_11', [1, 3, 60, 160], {'identity': True, 'op': 'Identity'}),
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1, 'op': 'Concat'}),
**result(),
**result('output'),
**result('output_1'),
}
@ -110,6 +111,29 @@ class TestInterpolateConcat(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_interpolate_concat_negate(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', 'identity_00'),
*connect('interpolate', 'identity_01'),
*connect('identity_00', 'output'),
*connect('identity_01', 'output_1'),
], nodes_with_edges_only=True)
InterpolateWithConcat().find_and_replace_pattern(graph)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', 'identity_00'),
*connect('interpolate', 'identity_01'),
*connect('identity_00', 'output'),
*connect('identity_01', 'output_1'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
@generate(*[
{'concat': {'axis': None}},
{'concat': {'axis': 2}},