[MO] allow zero size dimensions in Slice output (#7791)

* allow zero size dimensions in Slice output

* updated comments regarding 0-size dimensions
This commit is contained in:
Pavel Esir 2021-10-20 13:16:06 +03:00 committed by GitHub
parent 21090f47b2
commit 9f64f77a3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 43 deletions

View File

@ -142,8 +142,6 @@ class Slice(Op):
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
if input_value is None or any(is_dynamic_slice(s) for s in slice_idx):
output_shape = get_shape_from_slice(input_shape, slice_idx)
if np.ma.any(output_shape <= 0):
raise Error('Output shape: {} of node "{}" contains non-positive values'.format(output_shape, node.name))
node.out_port(0).data.set_shape(output_shape)
else:
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])

View File

@ -57,6 +57,12 @@ class TestSliceOp(unittest.TestCase):
# steps are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], None, None,
[dynamic_dimension_value, dynamic_dimension_value]),
# negative steps and since after normalization starts < ends output shape has 0-size dimension
(None, [20], [1], [-1], [0], [-2], None, [0]),
# since starts == ends output shape has 0-size dimension
(None, [4], [1], [1], [0], [1], None, [0]),
# since starts > ends output shape has 0-size dimension
(None, [4], [2], [1], [0], [1], None, [0])
])
def test_slice_infer(self, inp_value, inp_shape, starts, ends, axes, steps, expected_value, expected_shape):
if inp_value is None:
@ -103,44 +109,3 @@ class TestSliceOp(unittest.TestCase):
if expected_value is not None:
self.assertTrue(strict_compare_tensors(slice_node.out_node().value, expected_value))
self.assertTrue(strict_compare_tensors(slice_node.out_node().shape, expected_shape))
# negative tests
@generate(*[
# 1D input with negative starts
(None, [1], [-1], [0], [-2], [-6], [20]),
# case when output shape has zero elements
(None, [1], [1], [0], [1], [0], [4])
])
def test_slice_infer_negative(self, inp_value, starts, ends, axes, steps, expected, inp_shape=None):
if inp_value is None:
input_node = shaped_data('data_1', int64_array(inp_shape))
else:
input_node = valued_data('data_1', int64_array(inp_value))
def convert_args(val, name=''):
if val is not None:
return valued_const_with_data(name, int64_array(val))
else:
return shaped_const_with_data(name, [0]) #fake shape
starts = convert_args(starts, 'starts')
ends = convert_args(ends, 'ends')
axes = convert_args(axes, 'axes')
steps = convert_args(steps, 'steps')
nodes = {**input_node,
**regular_op_with_empty_data('slice', {'op': 'Slice'}),
**starts, **ends, **axes, **steps
}
graph = build_graph(nodes,
[('data_1', 'slice'),
*connect('starts', '1:slice'),
*connect('ends', '2:slice'),
*connect('axes', '3:slice'),
*connect('steps', '4:slice'),
*connect('slice', 'slice_d')])
graph.stage = 'middle'
slice_node = Node(graph, 'slice')
self.assertRaises(Error, Slice.infer, slice_node)