Fix StridedSlice replacer order and input permutation when strides are not specified (#4545)
This commit is contained in:
parent
92d750747c
commit
29612f15e3
@ -102,6 +102,10 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
|
||||
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
||||
return [LayoutChangeForConstantShapePaths]
|
||||
|
||||
def run_after(self):
|
||||
from extensions.middle.SliceConverter import ConvertSlice
|
||||
return [ConvertSlice]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(type='StridedSlice'):
|
||||
StridedSliceNormalizer.normalize_strided_slice(graph, node)
|
||||
@ -116,7 +120,8 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
|
||||
# Until now it was not possible to set correct permutations
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:1', 'slice', 'dim_size')
|
||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:2', 'slice', 'dim_size')
|
||||
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')
|
||||
if node.is_in_port_connected(3):
|
||||
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:3', 'slice', 'dim_size')
|
||||
|
||||
@staticmethod
|
||||
def normalize_strided_slice(graph: Graph, node: Node):
|
||||
@ -157,13 +162,13 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
|
||||
if i == 3 and not node.is_in_port_connected(3):
|
||||
continue # no need to extend strides if they are not connected
|
||||
|
||||
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
|
||||
blank_values_node = Const(graph, {'name': node_name + '/const_to_unroll_{}_ellipsis'.format(input_name),
|
||||
'value': int64_array(blank_values_arr)}).create_node()
|
||||
|
||||
if i == 3 and node.in_port(3).disconnected():
|
||||
continue # no need to extend strides if they are not connected
|
||||
|
||||
concat_in_ports_count = 3 if ellipsis_start != 0 else 2
|
||||
concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
|
||||
'in_ports_count': concat_in_ports_count}).create_node()
|
||||
@ -190,13 +195,13 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
|
||||
if i == 3 and not node.is_in_port_connected(3):
|
||||
continue # no need to extend strides if they are not connected
|
||||
|
||||
blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
|
||||
blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
|
||||
'value': int64_array(blank_values_arr)}).create_node()
|
||||
|
||||
if i == 3 and node.in_port(3).disconnected():
|
||||
continue # no need to extend strides if they are not connected
|
||||
|
||||
if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
|
||||
# concat already exists
|
||||
concat = node.in_port(i).get_source().node
|
||||
@ -227,7 +232,7 @@ class StridedSliceNormalizer(MiddleReplacementPattern):
|
||||
if strides is None:
|
||||
raise Error('StridedSlice operation for node {} supports only constant strides input'.format(node_name))
|
||||
else:
|
||||
strides = np.ones(slice_rank)
|
||||
strides = np.ones(len(node['slices']), dtype=np.int32)
|
||||
|
||||
num_ellipsis_inserts = len(data_shape) - slice_rank + np.count_nonzero(node.new_axis_mask) + 1
|
||||
res_slices = []
|
||||
|
@ -38,6 +38,12 @@ edges = (
|
||||
*connect('strided_slice', 'res')
|
||||
)
|
||||
|
||||
edges_without_strides = (
|
||||
*connect('input', '0:strided_slice'),
|
||||
*connect('begin', '1:strided_slice'),
|
||||
*connect('end', '2:strided_slice'),
|
||||
*connect('strided_slice', 'res')
|
||||
)
|
||||
|
||||
class TestStridedSliceNormalizer(unittest.TestCase):
|
||||
|
||||
@ -108,6 +114,65 @@ class TestStridedSliceNormalizer(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
|
||||
|
||||
def test_strided_slice_extend_inputs_without_strides(self):
|
||||
input_shape = (16, 100, 100, 3)
|
||||
nodes = {
|
||||
**valued_const_with_data('input', np.arange(np.product(input_shape)).reshape(*input_shape)),
|
||||
**regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice',
|
||||
'type': 'StridedSlice',
|
||||
'begin_mask': [1, 1, 1],
|
||||
'end_mask': [1, 1, 1],
|
||||
'shrink_axis_mask': [1, 0, 0],
|
||||
'new_axis_mask': [0, 0, 0],
|
||||
'ellipsis_mask': [0, 0, 0],
|
||||
'infer': StridedSlice.infer}),
|
||||
|
||||
**regular_op_with_empty_data('strided_slice_ref', {'op': 'StridedSlice',
|
||||
'type': 'StridedSlice',
|
||||
'begin_mask': [1, 1, 1, 0],
|
||||
'end_mask': [1, 1, 1, 0],
|
||||
'new_axis_mask': [0, 0, 0, 0],
|
||||
'shrink_axis_mask': [1, 0, 0, 0],
|
||||
'ellipsis_mask': [0, 0, 0, 0],
|
||||
'infer': StridedSlice.infer}),
|
||||
**valued_const_with_data('begin', int64_array([0, 0, 0])),
|
||||
**valued_const_with_data('begin_placeholder', int64_array([0])),
|
||||
**regular_op_with_empty_data('begin_concat',
|
||||
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
|
||||
**valued_const_with_data('end', int64_array([4, 25, 50])),
|
||||
**valued_const_with_data('end_placeholder', int64_array([0])),
|
||||
**regular_op_with_empty_data('end_concat',
|
||||
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
|
||||
**regular_op('res', {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: None})
|
||||
}
|
||||
|
||||
edges_ref_extended_inputs = (
|
||||
*connect('input', '0:strided_slice_ref'),
|
||||
|
||||
*connect('begin', '0:begin_concat'),
|
||||
*connect('begin_placeholder', '1:begin_concat'),
|
||||
*connect('begin_concat', '1:strided_slice_ref'),
|
||||
|
||||
*connect('end', '0:end_concat'),
|
||||
*connect('end_placeholder', '1:end_concat'),
|
||||
*connect('end_concat', '2:strided_slice_ref'),
|
||||
|
||||
*connect('strided_slice_ref', 'res')
|
||||
)
|
||||
|
||||
graph = build_graph(nodes, edges_without_strides, nodes_with_edges_only=True)
|
||||
graph_ref = build_graph(nodes, edges_ref_extended_inputs, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref.stage = 'middle'
|
||||
|
||||
graph = partial_infer(graph)
|
||||
StridedSliceNormalizer().find_and_replace_pattern(graph)
|
||||
graph = partial_infer(graph)
|
||||
graph_ref = partial_infer(graph_ref)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
|
||||
|
||||
def test_strided_slice_unrooll_ellipsis(self):
|
||||
input_shape = (10, 10, 10, 10)
|
||||
# out = inp[1:4, ..., 0:5] -> inp[1:4, :, :, 0:5] => out_shape = (3, 10, 10, 5)
|
||||
@ -204,6 +269,84 @@ class TestStridedSliceNormalizer(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
|
||||
|
||||
def test_strided_slice_unrooll_ellipsis_without_strides(self):
|
||||
input_shape = (10, 10, 10, 10)
|
||||
# out = inp[1:4, ..., 0:5] -> inp[1:4, :, :, 0:5] => out_shape = (3, 10, 10, 5)
|
||||
ellipsis_start = 1
|
||||
|
||||
nodes = {
|
||||
**valued_const_with_data('input', np.arange(np.product(input_shape)).reshape(*input_shape)),
|
||||
**regular_op_with_empty_data('strided_slice', {'op': 'StridedSlice', 'type': 'StridedSlice',
|
||||
'begin_mask': [1, 1, 1], 'end_mask': [1, 1, 1],
|
||||
'shrink_axis_mask': [0, 0, 0],
|
||||
'new_axis_mask': [0, 0, 0],
|
||||
'ellipsis_mask': [0, 1, 0],
|
||||
'infer': StridedSlice.infer}),
|
||||
|
||||
**regular_op_with_empty_data('strided_slice_ref', {'op': 'StridedSlice', 'begin_mask': [1, 0, 0, 1],
|
||||
'end_mask': [1, 0, 0, 1], 'ellipsis_mask': [0, 0, 0, 0],
|
||||
'new_axis_mask': [0, 0, 0, 0],
|
||||
'shrink_axis_mask': [0, 0, 0, 0],
|
||||
'infer': StridedSlice.infer}),
|
||||
|
||||
**valued_const_with_data('begin', int64_array([1, 0, 0])),
|
||||
**valued_const_with_data('split_axis_begin', int64_array(0)),
|
||||
**valued_const_with_data('splits_lengths_begin', int64_array([ellipsis_start, -1])),
|
||||
**regular_op_with_empty_data('split_for_begin', {'op': 'VariadicSplit', 'infer': VariadicSplit.infer}),
|
||||
**empty_data('split_for_begin_data_1'),
|
||||
**valued_const_with_data('begin_placeholder', int64_array([0])),
|
||||
**regular_op_with_empty_data('begin_concat',
|
||||
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
|
||||
|
||||
|
||||
**valued_const_with_data('end', int64_array([4, 0, 5])),
|
||||
**valued_const_with_data('split_axis_end', int64_array(0)),
|
||||
**valued_const_with_data('splits_lengths_end', int64_array([ellipsis_start, -1])),
|
||||
**regular_op_with_empty_data('split_for_end', {'op': 'VariadicSplit', 'infer': VariadicSplit.infer}),
|
||||
**empty_data('split_for_end_data_1'),
|
||||
**valued_const_with_data('end_placeholder', int64_array([0])),
|
||||
**regular_op_with_empty_data('end_concat',
|
||||
{'op': 'Concat', 'infer': concat_infer, 'axis': 0, 'dim_attrs': {}}),
|
||||
|
||||
**regular_op('res', {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: None})
|
||||
}
|
||||
|
||||
edges_ref_ellipsis_unrolled = (
|
||||
*connect('input', '0:strided_slice_ref'),
|
||||
|
||||
*connect('begin', '0:split_for_begin'),
|
||||
*connect('split_axis_begin', '1:split_for_begin'),
|
||||
*connect('splits_lengths_begin', '2:split_for_begin'),
|
||||
*connect('split_for_begin:0', '0:begin_concat'),
|
||||
*connect('begin_placeholder', '1:begin_concat'),
|
||||
('split_for_begin', 'split_for_begin_data_1', {'out': 1, 'in': 2}),
|
||||
('split_for_begin_data_1', 'begin_concat', {'out': 1, 'in': 2}),
|
||||
*connect('begin_concat', '1:strided_slice_ref'),
|
||||
|
||||
*connect('end', '0:split_for_end'),
|
||||
*connect('split_axis_end', '1:split_for_end'),
|
||||
*connect('splits_lengths_end', '2:split_for_end'),
|
||||
*connect('split_for_end:0', '0:end_concat'),
|
||||
*connect('end_placeholder', '1:end_concat'),
|
||||
('split_for_end', 'split_for_end_data_1', {'out': 1, 'in': 2}),
|
||||
('split_for_end_data_1', 'end_concat', {'out': 1, 'in': 2}),
|
||||
*connect('end_concat', '2:strided_slice_ref'),
|
||||
|
||||
*connect('strided_slice_ref', 'res')
|
||||
)
|
||||
|
||||
graph = build_graph(nodes, edges_without_strides, nodes_with_edges_only=True)
|
||||
graph_ref = build_graph(nodes, edges_ref_ellipsis_unrolled, nodes_with_edges_only=True)
|
||||
graph.stage = 'middle'
|
||||
graph_ref.stage = 'middle'
|
||||
graph = partial_infer(graph)
|
||||
StridedSliceNormalizer().find_and_replace_pattern(graph)
|
||||
graph = partial_infer(graph)
|
||||
graph_ref = partial_infer(graph_ref)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=False)
|
||||
self.assertTrue(flag, 'Graphs after StridedSliceNormalizer do not match to reference: {}'.format(resp))
|
||||
|
||||
|
||||
class TestStridedSliceShapeInferAfterNormalizer(unittest.TestCase):
|
||||
# check that after inserting Splits and Concats we still get the same shape
|
||||
|
@ -132,7 +132,7 @@ def get_shape_from_slice(input_shape: np.ndarray, slices: List) -> np.ndarray:
|
||||
in_idx += 1
|
||||
elif s is np.newaxis:
|
||||
output_shape.append(1)
|
||||
elif isinstance(s, int): # shrink_axis
|
||||
elif type(s) in [int, np.int, np.int32, np.int64]: # shrink_axis
|
||||
in_idx += 1
|
||||
elif s is Ellipsis:
|
||||
for idx in range(num_ellipsis_inserts):
|
||||
|
Loading…
Reference in New Issue
Block a user