InterpolateWithConcat pass fix (#1501)
* Fix InterpolateWithConcat pass * Add test for None case
This commit is contained in:
parent
cbdfa38392
commit
4e1f7d2b96
@ -171,11 +171,12 @@ class InterpolateWithConcat(FrontReplacementPattern):
|
||||
|
||||
# Interpolate could be connected to Concat through identity operations, skipping them
|
||||
next_node = self.get_single_output_destination_safely(interpolate)
|
||||
while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'):
|
||||
node = self.get_single_output_destination_safely(next_node)
|
||||
if node is not None:
|
||||
next_node = node
|
||||
else:
|
||||
break
|
||||
if next_node.soft_get('type') == 'Concat':
|
||||
self.make_interpolate_reshape_able(interpolate, next_node)
|
||||
if next_node is not None:
|
||||
while next_node.soft_get('type') != 'Concat' and next_node.has_and_set('identity'):
|
||||
node = self.get_single_output_destination_safely(next_node)
|
||||
if node is not None:
|
||||
next_node = node
|
||||
else:
|
||||
break
|
||||
if next_node.soft_get('type') == 'Concat':
|
||||
self.make_interpolate_reshape_able(interpolate, next_node)
|
||||
|
@ -44,7 +44,8 @@ nodes = {
|
||||
**regular_op_with_shaped_data('identity_11', [1, 3, 60, 160], {'identity': True, 'op': 'Identity'}),
|
||||
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1, 'op': 'Concat'}),
|
||||
|
||||
**result(),
|
||||
**result('output'),
|
||||
**result('output_1'),
|
||||
}
|
||||
|
||||
|
||||
@ -110,6 +111,29 @@ class TestInterpolateConcat(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_interpolate_concat_negate(self):
|
||||
graph = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', 'identity_00'),
|
||||
*connect('interpolate', 'identity_01'),
|
||||
*connect('identity_00', 'output'),
|
||||
*connect('identity_01', 'output_1'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
InterpolateWithConcat().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('placeholder', '0:interpolate'),
|
||||
*connect('out_shape', '1:interpolate'),
|
||||
*connect('interpolate', 'identity_00'),
|
||||
*connect('interpolate', 'identity_01'),
|
||||
*connect('identity_00', 'output'),
|
||||
*connect('identity_01', 'output_1'),
|
||||
], nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
@generate(*[
|
||||
{'concat': {'axis': None}},
|
||||
{'concat': {'axis': 2}},
|
||||
|
Loading…
Reference in New Issue
Block a user