From 29612f15e3cf1aa073d08544ac761720c289e28a Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Wed, 3 Mar 2021 11:16:56 +0300 Subject: [PATCH] Fix StridedSlice replacer order and input permutation when strides are not specified (#4545) --- .../middle/StridedSliceNormalizer.py | 21 ++- .../middle/StridedSliceNormalizer_test.py | 143 ++++++++++++++++++ .../mo/front/common/partial_infer/utils.py | 2 +- 3 files changed, 157 insertions(+), 9 deletions(-) diff --git a/model-optimizer/extensions/middle/StridedSliceNormalizer.py b/model-optimizer/extensions/middle/StridedSliceNormalizer.py index 267ba9cc8bd..44b1cbd12de 100644 --- a/model-optimizer/extensions/middle/StridedSliceNormalizer.py +++ b/model-optimizer/extensions/middle/StridedSliceNormalizer.py @@ -102,6 +102,10 @@ class StridedSliceNormalizer(MiddleReplacementPattern): from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths return [LayoutChangeForConstantShapePaths] + def run_after(self): + from extensions.middle.SliceConverter import ConvertSlice + return [ConvertSlice] + def find_and_replace_pattern(self, graph: Graph): for node in graph.get_op_nodes(type='StridedSlice'): StridedSliceNormalizer.normalize_strided_slice(graph, node) @@ -116,7 +120,8 @@ class StridedSliceNormalizer(MiddleReplacementPattern): # Until now it was not possible to set correct permutations PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'slice', 'dim_size') PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:2', 'slice', 'dim_size') - PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size') + if node.is_in_port_connected(3): + PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size') @staticmethod def normalize_strided_slice(graph: Graph, node: Node): @@ -157,13 +162,13 @@ class StridedSliceNormalizer(MiddleReplacementPattern): node_name = node.soft_get('name', node.id) for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]: + if i == 3 and not node.is_in_port_connected(3): + continue # no need to extend strides if they are not connected + blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions) blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name), 'value': int64_array(blank_values_arr)}).create_node() - if i == 3 and node.in_port(3).disconnected(): - continue # no need to extend strides if they are not connected - concat_in_ports_count = 3 if ellipsis_start != 0 else 2 concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name), 'in_ports_count': concat_in_ports_count}).create_node() @@ -190,13 +195,13 @@ class StridedSliceNormalizer(MiddleReplacementPattern): node_name = node.soft_get('name', node.id) for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]: + if i == 3 and not node.is_in_port_connected(3): + continue # no need to extend strides if they are not connected + blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions) blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name), 'value': int64_array(blank_values_arr)}).create_node() - if i == 3 and node.in_port(3).disconnected(): - continue # no need to extend strides if they are not connected - if node.in_port(i).get_source().node.soft_get('type') == 'Concat': # concat already exists concat = node.in_port(i).get_source().node @@ -227,7 +232,7 @@ class StridedSliceNormalizer(MiddleReplacementPattern): if strides is None: raise Error('StridedSlice operation for node {} supports only constant strides input'.format(node_name)) else: - strides = np.ones(slice_rank) + strides = np.ones(len(node['slices']), dtype=np.int32) num_ellipsis_inserts = len(data_shape) - slice_rank + np.count_nonzero(node.new_axis_mask) + 1 res_slices = [] diff --git a/model-optimizer/extensions/middle/StridedSliceNormalizer_test.py b/model-optimizer/extensions/middle/StridedSliceNormalizer_test.py index ba560048eb1..508c0b6226f 100644 --- a/model-optimizer/extensions/middle/StridedSliceNormalizer_test.py +++ b/model-optimizer/extensions/middle/StridedSliceNormalizer_test.py @@ -38,6 +38,12 @@ edges = ( *connect('strided_slice', 'res') ) +edges_without_strides = ( + *connect('input', '0:strided_slice'), + *connect('begin', '1:strided_slice'), + *connect('end', '2:strided_slice'), + *connect('strided_slice', 'res') +) class TestStridedSliceNormalizer(unittest.TestCase): @@ -108,6 +114,65 @@ class TestStridedSliceNormalizer(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False) self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp)) + def test_strided_slice_extend_inputs_without_strides(self): + input_shape = (16, 100, 100, 3) + nodes = { + **valued_const_with_data('input', np.arange(np.product(input_shape)).reshape(*input_shape)), + **regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice', + 'type': 'StridedSlice', + 'begin_mask': [1, 1, 1], + 'end_mask': [1, 1, 1], + 'shrink_axis_mask': [1, 0, 0], + 'new_axis_mask': [0, 0, 0], + 'ellipsis_mask': [0, 0, 0], + 'infer': StridedSlice.infer}), + + **regular_op_with_empty_data('strided_slice_ref', {'op': 'StridedSlice', + 'type': 'StridedSlice', + 'begin_mask': [1, 1, 1, 0], + 'end_mask': [1, 1, 1, 0], + 'new_axis_mask': [0, 0, 0, 0], + 'shrink_axis_mask': [1, 0, 0, 0], + 'ellipsis_mask': [0, 0, 0, 0], + 'infer': StridedSlice.infer}), + **valued_const_with_data('begin', int64_array([0, 0, 0])), + **valued_const_with_data('begin_placeholder', int64_array([0])), + **regular_op_with_empty_data('begin_concat', + {'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}), + **valued_const_with_data('end', int64_array([4, 25, 50])), + **valued_const_with_data('end_placeholder', int64_array([0])), + **regular_op_with_empty_data('end_concat', + {'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}), + **regular_op('res', {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: None}) + } + + edges_ref_extended_inputs = ( + *connect('input', '0:strided_slice_ref'), + + *connect('begin', '0:begin_concat'), + *connect('begin_placeholder', '1:begin_concat'), + *connect('begin_concat', '1:strided_slice_ref'), + + *connect('end', '0:end_concat'), + *connect('end_placeholder', '1:end_concat'), + *connect('end_concat', '2:strided_slice_ref'), + + *connect('strided_slice_ref', 'res') + ) + + graph = build_graph(nodes, edges_without_strides, nodes_with_edges_only=True) + graph_ref = build_graph(nodes, edges_ref_extended_inputs, nodes_with_edges_only=True) + graph.stage = 'middle' + graph_ref.stage = 'middle' + + graph = partial_infer(graph) + StridedSliceNormalizer().find_and_replace_pattern(graph) + graph = partial_infer(graph) + graph_ref = partial_infer(graph_ref) + + (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False) + self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp)) + def test_strided_slice_unrooll_ellipsis(self): input_shape = (10, 10, 10, 10) # out = inp[1:4, ..., 0:5] -> inp[1:4, :, :, 0:5] => out_shape = (3, 10, 10, 5) @@ -204,6 +269,84 @@ class TestStridedSliceNormalizer(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False) self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp)) + def test_strided_slice_unrooll_ellipsis_without_strides(self): + input_shape = (10, 10, 10, 10) + # out = inp[1:4, ..., 0:5] -> inp[1:4, :, :, 0:5] => out_shape = (3, 10, 10, 5) + ellipsis_start = 1 + + nodes = { + **valued_const_with_data('input', np.arange(np.product(input_shape)).reshape(*input_shape)), + **regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice', 'type': 'StridedSlice', + 'begin_mask': [1, 1, 1], 'end_mask': [1, 1, 1], + 'shrink_axis_mask': [0, 0, 0], + 'new_axis_mask': [0, 0, 0], + 'ellipsis_mask': [0, 1, 0], + 'infer': StridedSlice.infer}), + + **regular_op_with_empty_data('strided_slice_ref', {'op': 'StridedSlice', 'begin_mask': [1, 0, 0, 1], + 'end_mask': [1, 0, 0, 1], 'ellipsis_mask': [0, 0, 0, 0], + 'new_axis_mask': [0, 0, 0, 0], + 'shrink_axis_mask': [0, 0, 0, 0], + 'infer': StridedSlice.infer}), + + **valued_const_with_data('begin', int64_array([1, 0, 0])), + **valued_const_with_data('split_axis_begin', int64_array(0)), + **valued_const_with_data('splits_lengths_begin', int64_array([ellipsis_start, -1])), + **regular_op_with_empty_data('split_for_begin', {'op': 'VariadicSplit', 'infer': VariadicSplit.infer}), + **empty_data('split_for_begin_data_1'), + **valued_const_with_data('begin_placeholder', int64_array([0])), + **regular_op_with_empty_data('begin_concat', + {'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}), + + + **valued_const_with_data('end', int64_array([4, 0, 5])), + **valued_const_with_data('split_axis_end', int64_array(0)), + **valued_const_with_data('splits_lengths_end', int64_array([ellipsis_start, -1])), + **regular_op_with_empty_data('split_for_end', {'op': 'VariadicSplit', 'infer': VariadicSplit.infer}), + **empty_data('split_for_end_data_1'), + **valued_const_with_data('end_placeholder', int64_array([0])), + **regular_op_with_empty_data('end_concat', + {'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}), + + **regular_op('res', {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: None}) + } + + edges_ref_ellipsis_unrolled = ( + *connect('input', '0:strided_slice_ref'), + + *connect('begin', '0:split_for_begin'), + *connect('split_axis_begin', '1:split_for_begin'), + *connect('splits_lengths_begin', '2:split_for_begin'), + *connect('split_for_begin:0', '0:begin_concat'), + *connect('begin_placeholder', '1:begin_concat'), + ('split_for_begin', 'split_for_begin_data_1', {'out': 1, 'in': 2}), + ('split_for_begin_data_1', 'begin_concat', {'out': 1, 'in': 2}), + *connect('begin_concat', '1:strided_slice_ref'), + + *connect('end', '0:split_for_end'), + *connect('split_axis_end', '1:split_for_end'), + *connect('splits_lengths_end', '2:split_for_end'), + *connect('split_for_end:0', '0:end_concat'), + *connect('end_placeholder', '1:end_concat'), + ('split_for_end', 'split_for_end_data_1', {'out': 1, 'in': 2}), + ('split_for_end_data_1', 'end_concat', {'out': 1, 'in': 2}), + *connect('end_concat', '2:strided_slice_ref'), + + *connect('strided_slice_ref', 'res') + ) + + graph = build_graph(nodes, edges_without_strides, nodes_with_edges_only=True) + graph_ref = build_graph(nodes, edges_ref_ellipsis_unrolled, nodes_with_edges_only=True) + graph.stage = 'middle' + graph_ref.stage = 'middle' + graph = partial_infer(graph) + StridedSliceNormalizer().find_and_replace_pattern(graph) + graph = partial_infer(graph) + graph_ref = partial_infer(graph_ref) + + (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False) + self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp)) + class TestStridedSliceShapeInferAfterNormalizer(unittest.TestCase): # check that after inserting Splits and Concats we still get the same shape diff --git a/model-optimizer/mo/front/common/partial_infer/utils.py b/model-optimizer/mo/front/common/partial_infer/utils.py index 625e7c97bbb..c367610815d 100644 --- a/model-optimizer/mo/front/common/partial_infer/utils.py +++ b/model-optimizer/mo/front/common/partial_infer/utils.py @@ -132,7 +132,7 @@ def get_shape_from_slice(input_shape: np.ndarray, slices: List) -> np.ndarray: in_idx += 1 elif s is np.newaxis: output_shape.append(1) - elif isinstance(s, int): # shrink_axis + elif type(s) in [int, np.int, np.int32, np.int64]: # shrink_axis in_idx += 1 elif s is Ellipsis: for idx in range(num_ellipsis_inserts):