Fix StridedSlice replacer order and input permutation when strides are not specified (#4545)

This commit is contained in:
Pavel Esir 2021-03-03 11:16:56 +03:00 committed by GitHub
parent 92d750747c
commit 29612f15e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 157 additions and 9 deletions

View File

@ -102,6 +102,10 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
return [LayoutChangeForConstantShapePaths]
def run_after(self):
from extensions.middle.SliceConverter import ConvertSlice
return [ConvertSlice]
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='StridedSlice'):
StridedSliceNormalizer.normalize_strided_slice(graph, node)
@ -116,7 +120,8 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
# Until now it was not possible to set correct permutations
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'slice', 'dim_size')
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:2', 'slice', 'dim_size')
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')
if node.is_in_port_connected(3):
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')
@staticmethod
def normalize_strided_slice(graph: Graph, node: Node):
@ -157,13 +162,13 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
node_name = node.soft_get('name', node.id)
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
if i == 3 and not node.is_in_port_connected(3):
continue # no need to extend strides if they are not connected
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
'value': int64_array(blank_values_arr)}).create_node()
if i == 3 and node.in_port(3).disconnected():
continue # no need to extend strides if they are not connected
concat_in_ports_count = 3 if ellipsis_start != 0 else 2
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
'in_ports_count': concat_in_ports_count}).create_node()
@ -190,13 +195,13 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
node_name = node.soft_get('name', node.id)
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
if i == 3 and not node.is_in_port_connected(3):
continue # no need to extend strides if they are not connected
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
'value': int64_array(blank_values_arr)}).create_node()
if i == 3 and node.in_port(3).disconnected():
continue # no need to extend strides if they are not connected
if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
# concat already exists
concat = node.in_port(i).get_source().node
@ -227,7 +232,7 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
if strides is None:
raise Error('StridedSlice operation for node {} supports only constant strides input'.format(node_name))
else:
strides = np.ones(slice_rank)
strides = np.ones(len(node['slices']), dtype=np.int32)
num_ellipsis_inserts = len(data_shape) - slice_rank + np.count_nonzero(node.new_axis_mask) + 1
res_slices = []

View File

@ -38,6 +38,12 @@ edges = (
*connect('strided_slice', 'res')
)
edges_without_strides = (
*connect('input', '0:strided_slice'),
*connect('begin', '1:strided_slice'),
*connect('end', '2:strided_slice'),
*connect('strided_slice', 'res')
)
class TestStridedSliceNormalizer(unittest.TestCase):
@ -108,6 +114,65 @@ class TestStridedSliceNormalizer(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
def test_strided_slice_extend_inputs_without_strides(self):
input_shape = (16, 100, 100, 3)
nodes = {
**valued_const_with_data('input', np.arange(np.product(input_shape)).reshape(*input_shape)),
**regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice',
'type': 'StridedSlice',
'begin_mask': [1, 1, 1],
'end_mask': [1, 1, 1],
'shrink_axis_mask': [1, 0, 0],
'new_axis_mask': [0, 0, 0],
'ellipsis_mask': [0, 0, 0],
'infer': StridedSlice.infer}),
**regular_op_with_empty_data('strided_slice_ref', {'op': 'StridedSlice',
'type': 'StridedSlice',
'begin_mask': [1, 1, 1, 0],
'end_mask': [1, 1, 1, 0],
'new_axis_mask': [0, 0, 0, 0],
'shrink_axis_mask': [1, 0, 0, 0],
'ellipsis_mask': [0, 0, 0, 0],
'infer': StridedSlice.infer}),
**valued_const_with_data('begin', int64_array([0, 0, 0])),
**valued_const_with_data('begin_placeholder', int64_array([0])),
**regular_op_with_empty_data('begin_concat',
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
**valued_const_with_data('end', int64_array([4, 25, 50])),
**valued_const_with_data('end_placeholder', int64_array([0])),
**regular_op_with_empty_data('end_concat',
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
**regular_op('res', {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: None})
}
edges_ref_extended_inputs = (
*connect('input', '0:strided_slice_ref'),
*connect('begin', '0:begin_concat'),
*connect('begin_placeholder', '1:begin_concat'),
*connect('begin_concat', '1:strided_slice_ref'),
*connect('end', '0:end_concat'),
*connect('end_placeholder', '1:end_concat'),
*connect('end_concat', '2:strided_slice_ref'),
*connect('strided_slice_ref', 'res')
)
graph = build_graph(nodes, edges_without_strides, nodes_with_edges_only=True)
graph_ref = build_graph(nodes, edges_ref_extended_inputs, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref.stage = 'middle'
graph = partial_infer(graph)
StridedSliceNormalizer().find_and_replace_pattern(graph)
graph = partial_infer(graph)
graph_ref = partial_infer(graph_ref)
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
def test_strided_slice_unrooll_ellipsis(self):
input_shape = (10, 10, 10, 10)
# out = inp[1:4, ..., 0:5] -> inp[1:4, :, :, 0:5] => out_shape = (3, 10, 10, 5)
@ -204,6 +269,84 @@ class TestStridedSliceNormalizer(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
def test_strided_slice_unrooll_ellipsis_without_strides(self):
input_shape = (10, 10, 10, 10)
# out = inp[1:4, ..., 0:5] -> inp[1:4, :, :, 0:5] => out_shape = (3, 10, 10, 5)
ellipsis_start = 1
nodes = {
**valued_const_with_data('input', np.arange(np.product(input_shape)).reshape(*input_shape)),
**regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice', 'type': 'StridedSlice',
'begin_mask': [1, 1, 1], 'end_mask': [1, 1, 1],
'shrink_axis_mask': [0, 0, 0],
'new_axis_mask': [0, 0, 0],
'ellipsis_mask': [0, 1, 0],
'infer': StridedSlice.infer}),
**regular_op_with_empty_data('strided_slice_ref', {'op': 'StridedSlice', 'begin_mask': [1, 0, 0, 1],
'end_mask': [1, 0, 0, 1], 'ellipsis_mask': [0, 0, 0, 0],
'new_axis_mask': [0, 0, 0, 0],
'shrink_axis_mask': [0, 0, 0, 0],
'infer': StridedSlice.infer}),
**valued_const_with_data('begin', int64_array([1, 0, 0])),
**valued_const_with_data('split_axis_begin', int64_array(0)),
**valued_const_with_data('splits_lengths_begin', int64_array([ellipsis_start, -1])),
**regular_op_with_empty_data('split_for_begin', {'op': 'VariadicSplit', 'infer': VariadicSplit.infer}),
**empty_data('split_for_begin_data_1'),
**valued_const_with_data('begin_placeholder', int64_array([0])),
**regular_op_with_empty_data('begin_concat',
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
**valued_const_with_data('end', int64_array([4, 0, 5])),
**valued_const_with_data('split_axis_end', int64_array(0)),
**valued_const_with_data('splits_lengths_end', int64_array([ellipsis_start, -1])),
**regular_op_with_empty_data('split_for_end', {'op': 'VariadicSplit', 'infer': VariadicSplit.infer}),
**empty_data('split_for_end_data_1'),
**valued_const_with_data('end_placeholder', int64_array([0])),
**regular_op_with_empty_data('end_concat',
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
**regular_op('res', {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: None})
}
edges_ref_ellipsis_unrolled = (
*connect('input', '0:strided_slice_ref'),
*connect('begin', '0:split_for_begin'),
*connect('split_axis_begin', '1:split_for_begin'),
*connect('splits_lengths_begin', '2:split_for_begin'),
*connect('split_for_begin:0', '0:begin_concat'),
*connect('begin_placeholder', '1:begin_concat'),
('split_for_begin', 'split_for_begin_data_1', {'out': 1, 'in': 2}),
('split_for_begin_data_1', 'begin_concat', {'out': 1, 'in': 2}),
*connect('begin_concat', '1:strided_slice_ref'),
*connect('end', '0:split_for_end'),
*connect('split_axis_end', '1:split_for_end'),
*connect('splits_lengths_end', '2:split_for_end'),
*connect('split_for_end:0', '0:end_concat'),
*connect('end_placeholder', '1:end_concat'),
('split_for_end', 'split_for_end_data_1', {'out': 1, 'in': 2}),
('split_for_end_data_1', 'end_concat', {'out': 1, 'in': 2}),
*connect('end_concat', '2:strided_slice_ref'),
*connect('strided_slice_ref', 'res')
)
graph = build_graph(nodes, edges_without_strides, nodes_with_edges_only=True)
graph_ref = build_graph(nodes, edges_ref_ellipsis_unrolled, nodes_with_edges_only=True)
graph.stage = 'middle'
graph_ref.stage = 'middle'
graph = partial_infer(graph)
StridedSliceNormalizer().find_and_replace_pattern(graph)
graph = partial_infer(graph)
graph_ref = partial_infer(graph_ref)
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
class TestStridedSliceShapeInferAfterNormalizer(unittest.TestCase):
# check that after inserting Splits and Concats we still get the same shape

View File

@ -132,7 +132,7 @@ def get_shape_from_slice(input_shape: np.ndarray, slices: List) -> np.ndarray:
in_idx += 1
elif s is np.newaxis:
output_shape.append(1)
elif isinstance(s, int): # shrink_axis
elif type(s) in [int, np.int, np.int32, np.int64]: # shrink_axis
in_idx += 1
elif s is Ellipsis:
for idx in range(num_ellipsis_inserts):