[MO] allow zero size dimensions in Slice output (#7791)
* allow zero size dimensions in Slice output * updated comments regarding 0-size dimensions
This commit is contained in:
parent
21090f47b2
commit
9f64f77a3c
@ -142,8 +142,6 @@ class Slice(Op):
|
||||
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
|
||||
if input_value is None or any(is_dynamic_slice(s) for s in slice_idx):
|
||||
output_shape = get_shape_from_slice(input_shape, slice_idx)
|
||||
if np.ma.any(output_shape <= 0):
|
||||
raise Error('Output shape: {} of node "{}" contains non-positive values'.format(output_shape, node.name))
|
||||
node.out_port(0).data.set_shape(output_shape)
|
||||
else:
|
||||
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])
|
||||
|
@ -57,6 +57,12 @@ class TestSliceOp(unittest.TestCase):
|
||||
# steps are non-constant
|
||||
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], None, None,
|
||||
[dynamic_dimension_value, dynamic_dimension_value]),
|
||||
# negative steps and since after normalization starts < ends output shape has 0-size dimension
|
||||
(None, [20], [1], [-1], [0], [-2], None, [0]),
|
||||
# since starts == ends output shape has 0-size dimension
|
||||
(None, [4], [1], [1], [0], [1], None, [0]),
|
||||
# since starts > ends output shape has 0-size dimension
|
||||
(None, [4], [2], [1], [0], [1], None, [0])
|
||||
])
|
||||
def test_slice_infer(self, inp_value, inp_shape, starts, ends, axes, steps, expected_value, expected_shape):
|
||||
if inp_value is None:
|
||||
@ -103,44 +109,3 @@ class TestSliceOp(unittest.TestCase):
|
||||
if expected_value is not None:
|
||||
self.assertTrue(strict_compare_tensors(slice_node.out_node().value, expected_value))
|
||||
self.assertTrue(strict_compare_tensors(slice_node.out_node().shape, expected_shape))
|
||||
|
||||
# negative tests
|
||||
@generate(*[
|
||||
# 1D input with negative starts
|
||||
(None, [1], [-1], [0], [-2], [-6], [20]),
|
||||
# case when output shape has zero elements
|
||||
(None, [1], [1], [0], [1], [0], [4])
|
||||
])
|
||||
def test_slice_infer_negative(self, inp_value, starts, ends, axes, steps, expected, inp_shape=None):
|
||||
if inp_value is None:
|
||||
input_node = shaped_data('data_1', int64_array(inp_shape))
|
||||
else:
|
||||
input_node = valued_data('data_1', int64_array(inp_value))
|
||||
|
||||
def convert_args(val, name=''):
|
||||
if val is not None:
|
||||
return valued_const_with_data(name, int64_array(val))
|
||||
else:
|
||||
return shaped_const_with_data(name, [0]) #fake shape
|
||||
|
||||
starts = convert_args(starts, 'starts')
|
||||
ends = convert_args(ends, 'ends')
|
||||
axes = convert_args(axes, 'axes')
|
||||
steps = convert_args(steps, 'steps')
|
||||
|
||||
nodes = {**input_node,
|
||||
**regular_op_with_empty_data('slice', {'op': 'Slice'}),
|
||||
**starts, **ends, **axes, **steps
|
||||
}
|
||||
|
||||
graph = build_graph(nodes,
|
||||
[('data_1', 'slice'),
|
||||
*connect('starts', '1:slice'),
|
||||
*connect('ends', '2:slice'),
|
||||
*connect('axes', '3:slice'),
|
||||
*connect('steps', '4:slice'),
|
||||
*connect('slice', 'slice_d')])
|
||||
|
||||
graph.stage = 'middle'
|
||||
slice_node = Node(graph, 'slice')
|
||||
self.assertRaises(Error, Slice.infer, slice_node)
|
||||
|
Loading…
Reference in New Issue
Block a user