Fix Slice issue in MO IR Reader (#13784)

* Fix slice issue in MO IR Reader

* Add unit test

* Fix slice test
This commit is contained in:
Maxim Vafin 2022-11-07 08:22:22 +01:00 committed by GitHub
parent 7bf6faf4cb
commit f1d7647b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 240 additions and 115 deletions

View File

@ -3,7 +3,8 @@
import numpy as np import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import get_shape_from_slice, shape_array, dynamic_dimension_value, \ from openvino.tools.mo.front.common.partial_infer.utils import get_shape_from_slice, shape_array, \
dynamic_dimension_value, \
dynamic_dimension, is_dynamic_slice dynamic_dimension, is_dynamic_slice
from openvino.tools.mo.graph.graph import Node, Graph from openvino.tools.mo.graph.graph import Node, Graph
from openvino.tools.mo.ops.op import Op from openvino.tools.mo.ops.op import Op
@ -56,7 +57,6 @@ class CaffeSlice(Op):
}, attrs) }, attrs)
class TFSlice(Op): class TFSlice(Op):
""" """
TFSlice differs from Slice in ONNX, Caffe and MXNet. TFSlice differs from Slice in ONNX, Caffe and MXNet.
@ -97,6 +97,37 @@ class MXSlice(Op):
}, attrs) }, attrs)
def slice_infer(node: Node, steps_idx: int, axes_idx: int):
input_value = node.in_port(0).data.get_value()
input_shape = node.in_port(0).data.get_shape()
starts = node.in_port(1).data.get_value()
ends = node.in_port(2).data.get_value()
if node.is_in_port_connected(steps_idx):
steps = node.in_port(steps_idx).data.get_value()
else:
steps = np.ones(len(starts), dtype=np.int64)
if node.is_in_port_connected(axes_idx):
axes = node.in_port(axes_idx).data.get_value()
else:
axes = [x for x in range(len(starts))]
if starts is None or ends is None or steps is None or axes is None:
node.out_port(0).data.set_shape(shape_array([dynamic_dimension_value] * len(input_shape)))
return
slice_idx = [slice(0, in_shape, 1) for in_shape in input_shape]
for i in range(len(axes)):
# Ranged for output value for specified axis
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
if input_value is None or any(is_dynamic_slice(s) for s in slice_idx):
output_shape = get_shape_from_slice(input_shape, slice_idx)
node.out_port(0).data.set_shape(output_shape)
else:
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])
class Slice(Op): class Slice(Op):
""" """
Semantic of Slice is identical to Slice in ONNX opset >= 10. Semantic of Slice is identical to Slice in ONNX opset >= 10.
@ -117,31 +148,26 @@ class Slice(Op):
@staticmethod @staticmethod
def infer(node: Node): def infer(node: Node):
input_value = node.in_port(0).data.get_value() slice_infer(node, 4, 3)
input_shape = node.in_port(0).data.get_shape()
starts = node.in_port(1).data.get_value()
ends = node.in_port(2).data.get_value()
if node.is_in_port_connected(4):
steps = node.in_port(4).data.get_value()
else:
steps = np.ones(len(starts), dtype=np.int64)
if node.is_in_port_connected(3): class OvSlice(Op):
axes = node.in_port(3).data.get_value() """
else: Semantic of OvSlice is identical to Slice in Openvino opset8.
axes = [x for x in range(len(starts))] It is introduced for usage in MO IR Reader.
"""
op = 'OvSlice'
enabled = False
if starts is None or ends is None or steps is None or axes is None: def __init__(self, graph: Graph, attrs: dict = None):
node.out_port(0).data.set_shape(shape_array([dynamic_dimension_value] * len(input_shape))) super().__init__(graph, {
return 'type': None,
'op': self.op,
'in_ports_count': 5,
'out_ports_count': 1,
'infer': self.infer
}, attrs)
slice_idx = [slice(0, in_shape, 1) for in_shape in input_shape] @staticmethod
for i in range(len(axes)): def infer(node: Node):
# Ranged for output value for specified axis slice_infer(node, 3, 4)
slice_idx[axes[i]] = slice(starts[i], ends[i], steps[i])
if input_value is None or any(is_dynamic_slice(s) for s in slice_idx):
output_shape = get_shape_from_slice(input_shape, slice_idx)
node.out_port(0).data.set_shape(output_shape)
else:
node.out_port(0).data.set_value(input_value[tuple(slice_idx)])

View File

@ -27,6 +27,7 @@ from openvino.tools.mo.ops.pooling import Pooling
from openvino.tools.mo.ops.psroipooling import DeformablePSROIPoolingOp from openvino.tools.mo.ops.psroipooling import DeformablePSROIPoolingOp
from openvino.tools.mo.ops.scatter import Scatter from openvino.tools.mo.ops.scatter import Scatter
from openvino.tools.mo.ops.scatternd import ScatterNDBase from openvino.tools.mo.ops.scatternd import ScatterNDBase
from openvino.tools.mo.ops.slice import OvSlice
from openvino.tools.mo.ops.split import Split, VariadicSplit from openvino.tools.mo.ops.split import Split, VariadicSplit
from openvino.tools.mo.utils.class_registration import update_registration from openvino.tools.mo.utils.class_registration import update_registration
from openvino.tools.mo.utils.import_extensions import import_by_path from openvino.tools.mo.utils.import_extensions import import_by_path
@ -47,6 +48,7 @@ custom_ops = {
'MaxPool': Pooling, 'MaxPool': Pooling,
'Multiply': Mul, 'Multiply': Mul,
'Power': Pow, 'Power': Pow,
'Slice': OvSlice,
'Split': Split, 'Split': Split,
'Subtract': Sub, 'Subtract': Sub,
'VariadicSplit': VariadicSplit, 'VariadicSplit': VariadicSplit,

View File

@ -6,106 +6,203 @@ import unittest
import numpy as np import numpy as np
from generator import generator, generate from generator import generator, generate
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, dynamic_dimension_value, shape_array, strict_compare_tensors from openvino.tools.mo.front.common.partial_infer.utils import int64_array, dynamic_dimension_value, shape_array, \
strict_compare_tensors
from openvino.tools.mo.graph.graph import Node from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.slice import Slice from openvino.tools.mo.ops.slice import Slice, OvSlice
from openvino.tools.mo.utils.error import Error
from unit_tests.utils.graph import build_graph, valued_const_with_data, valued_data, regular_op_with_empty_data, \ from unit_tests.utils.graph import build_graph, valued_const_with_data, valued_data, regular_op_with_empty_data, \
connect, shaped_data, shaped_const_with_data connect, shaped_data, shaped_const_with_data
@generator @generator
class TestSliceOp(unittest.TestCase): class TestSliceOp(unittest.TestCase):
@generate(*[ @generate(*[
# standard case # standard case
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, 2], [0, 1], [1, 1], ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, 2], [0, 1], [1, 1],
[[5], [3], [6]], [3, 1]), [[5], [3], [6]], [3, 1]),
# negative bounds # negative bounds
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], [1, 1], ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], [1, 1],
[[5], [3], [6]], [3, 1]), [[5], [3], [6]], [3, 1]),
# unusual order of axes # unusual order of axes
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [1, 0], [1, 1], ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [1, 0], [1, 1],
[[2, 3, 5]], [1, 3]), [[2, 3, 5]], [1, 3]),
# when only input_shape is defined without values (one from bottom element is shape) # when only input_shape is defined without values (one from bottom element is shape)
(None, [4, 5, 6], [1, 2], [4, 3], [0, 1], [1, 1], None, [3, 1, 6]), (None, [4, 5, 6], [1, 2], [4, 3], [0, 1], [1, 1], None, [3, 1, 6]),
# boundary case # boundary case
(None, [4, 5, 6], [0, 2], [np.iinfo(np.int32).max, 3], [0, 1], [1, 1], None, [4, 1, 6]), (None, [4, 5, 6], [0, 2], [np.iinfo(np.int32).max, 3], [0, 1], [1, 1], None, [4, 1, 6]),
# boundary case # boundary case
(None, [4, 5, 6], [np.iinfo(np.int32).min, 2], [3, 3], [0, 1], [1, 1], None, [3, 1, 6],), (None, [4, 5, 6], [np.iinfo(np.int32).min, 2], [3, 3], [0, 1], [1, 1], None, [3, 1, 6],),
# 1D input # 1D input
([1, 3, 224, 224], [4], [1], [2], [0], [1], [3], [1]), ([1, 3, 224, 224], [4], [1], [2], [0], [1], [3], [1]),
# 1D input with negative starts # 1D input with negative starts
(None, [4], [-1], [1], [0], [-1], None, [2]), (None, [4], [-1], [1], [0], [-1], None, [2]),
# 1D input with negative ends # 1D input with negative ends
(None, [4], [1], [-1], [0], [1], None, [2]), (None, [4], [1], [-1], [0], [1], None, [2]),
# with rounding (e.g. take from 1st to 3rd with step 4 should give shape 1 not 0) # with rounding (e.g. take from 1st to 3rd with step 4 should give shape 1 not 0)
(None, [4], [1], [3], [0], [4], None, [1]), (None, [4], [1], [3], [0], [4], None, [1]),
# with rounding and negative steps (e.g. take from 1st to 3rd with step 4 should give shape 1 not 0) # with rounding and negative steps (e.g. take from 1st to 3rd with step 4 should give shape 1 not 0)
(None, [10], [7], [3], [0], [-7], None, [1]), (None, [10], [7], [3], [0], [-7], None, [1]),
# reversing the sequence of elements # reversing the sequence of elements
(None, [10], [-1], [np.iinfo(np.int32).min], [0], [-1], None, [10]), (None, [10], [-1], [np.iinfo(np.int32).min], [0], [-1], None, [10]),
# dynamic dimensions cases # dynamic dimensions cases
# starts are non-constant # starts are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], None, [3, 2], [0, 1], [1, 1], None, ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], None, [3, 2], [0, 1], [1, 1], None,
[dynamic_dimension_value, dynamic_dimension_value]), [dynamic_dimension_value, dynamic_dimension_value]),
# ends are non-constant # ends are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], None, [0, 1], [1, 1], None, ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], None, [0, 1], [1, 1], None,
[dynamic_dimension_value, dynamic_dimension_value]), [dynamic_dimension_value, dynamic_dimension_value]),
# axes are non-constant # axes are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], None, [1, 1], None, ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], None, [1, 1], None,
[dynamic_dimension_value, dynamic_dimension_value]), [dynamic_dimension_value, dynamic_dimension_value]),
# steps are non-constant # steps are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], None, None, ([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], None, None,
[dynamic_dimension_value, dynamic_dimension_value]), [dynamic_dimension_value, dynamic_dimension_value]),
# negative steps and since after normalization starts < ends output shape has 0-size dimension # negative steps and since after normalization starts < ends output shape has 0-size dimension
(None, [20], [1], [-1], [0], [-2], None, [0]), (None, [20], [1], [-1], [0], [-2], None, [0]),
# since starts == ends output shape has 0-size dimension # since starts == ends output shape has 0-size dimension
(None, [4], [1], [1], [0], [1], None, [0]), (None, [4], [1], [1], [0], [1], None, [0]),
# since starts > ends output shape has 0-size dimension # since starts > ends output shape has 0-size dimension
(None, [4], [2], [1], [0], [1], None, [0]) (None, [4], [2], [1], [0], [1], None, [0])
]) ])
def test_slice_infer(self, inp_value, inp_shape, starts, ends, axes, steps, expected_value, expected_shape): def test_slice_infer(self, inp_value, inp_shape, starts, ends, axes, steps, expected_value, expected_shape):
if inp_value is None: if inp_value is None:
input_node = shaped_data('data_1', int64_array(inp_shape)) input_node = shaped_data('data_1', int64_array(inp_shape))
else:
input_node = valued_data('data_1', int64_array(inp_value))
if inp_value is not None and inp_shape is not None:
assert np.array_equal(np.array(inp_value).shape, inp_shape)
def convert_args(val, name=''):
if val is not None:
return valued_const_with_data(name, int64_array(val))
else: else:
input_node = valued_data('data_1', int64_array(inp_value)) return shaped_const_with_data(name, [0]) # fake shape
if inp_value is not None and inp_shape is not None:
assert np.array_equal(np.array(inp_value).shape, inp_shape)
def convert_args(val, name=''): starts = convert_args(starts, 'starts')
if val is not None: ends = convert_args(ends, 'ends')
return valued_const_with_data(name, int64_array(val)) axes = convert_args(axes, 'axes')
else: steps = convert_args(steps, 'steps')
return shaped_const_with_data(name, [0]) #fake shape if expected_shape is not None:
expected_shape = shape_array(expected_shape)
starts = convert_args(starts, 'starts') nodes = {
ends = convert_args(ends, 'ends') **input_node,
axes = convert_args(axes, 'axes') **regular_op_with_empty_data('slice', {'op': 'Slice'}),
steps = convert_args(steps, 'steps') **starts,
if expected_shape is not None: **ends,
expected_shape = shape_array(expected_shape) **axes,
**steps,
}
nodes = { graph = build_graph(nodes,
**input_node, [('data_1', 'slice'),
**regular_op_with_empty_data('slice', {'op': 'Slice'}), *connect('starts', '1:slice'),
**starts, *connect('ends', '2:slice'),
**ends, *connect('axes', '3:slice'),
**axes, *connect('steps', '4:slice'),
**steps, *connect('slice', 'slice_d')])
}
graph = build_graph(nodes, graph.stage = 'middle'
[('data_1', 'slice'), slice_node = Node(graph, 'slice')
*connect('starts', '1:slice'),
*connect('ends', '2:slice'),
*connect('axes', '3:slice'),
*connect('steps', '4:slice'),
*connect('slice', 'slice_d')])
graph.stage = 'middle' Slice.infer(slice_node)
slice_node = Node(graph, 'slice') if expected_value is not None:
self.assertTrue(strict_compare_tensors(slice_node.out_node().value, expected_value))
self.assertTrue(strict_compare_tensors(slice_node.out_node().shape, expected_shape))
Slice.infer(slice_node)
if expected_value is not None: @generator
self.assertTrue(strict_compare_tensors(slice_node.out_node().value, expected_value)) class TestOvSliceOp(unittest.TestCase):
self.assertTrue(strict_compare_tensors(slice_node.out_node().shape, expected_shape)) @generate(*[
# standard case
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, 2], [0, 1], [1, 1],
[[5], [3], [6]], [3, 1]),
# negative bounds
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], [1, 1],
[[5], [3], [6]], [3, 1]),
# unusual order of axes
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [1, 0], [1, 1],
[[2, 3, 5]], [1, 3]),
# when only input_shape is defined without values (one from bottom element is shape)
(None, [4, 5, 6], [1, 2], [4, 3], [0, 1], [1, 1], None, [3, 1, 6]),
# boundary case
(None, [4, 5, 6], [0, 2], [np.iinfo(np.int32).max, 3], [0, 1], [1, 1], None, [4, 1, 6]),
# boundary case
(None, [4, 5, 6], [np.iinfo(np.int32).min, 2], [3, 3], [0, 1], [1, 1], None, [3, 1, 6],),
# 1D input
([1, 3, 224, 224], [4], [1], [2], [0], [1], [3], [1]),
# 1D input with negative starts
(None, [4], [-1], [1], [0], [-1], None, [2]),
# 1D input with negative ends
(None, [4], [1], [-1], [0], [1], None, [2]),
# with rounding (e.g. take from 1st to 3rd with step 4 should give shape 1 not 0)
(None, [4], [1], [3], [0], [4], None, [1]),
# with rounding and negative steps (e.g. take from 1st to 3rd with step 4 should give shape 1 not 0)
(None, [10], [7], [3], [0], [-7], None, [1]),
# reversing the sequence of elements
(None, [10], [-1], [np.iinfo(np.int32).min], [0], [-1], None, [10]),
# dynamic dimensions cases
# starts are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], None, [3, 2], [0, 1], [1, 1], None,
[dynamic_dimension_value, dynamic_dimension_value]),
# ends are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], None, [0, 1], [1, 1], None,
[dynamic_dimension_value, dynamic_dimension_value]),
# axes are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], None, [1, 1], None,
[dynamic_dimension_value, dynamic_dimension_value]),
# steps are non-constant
([[4, 5, 6, 7], [2, 3, 5, 6], [5, 6, 8, 9], [5, 6, 8, 9]], [4, 4], [0, 1], [3, -2], [0, 1], None, None,
[dynamic_dimension_value, dynamic_dimension_value]),
# negative steps and since after normalization starts < ends output shape has 0-size dimension
(None, [20], [1], [-1], [0], [-2], None, [0]),
# since starts == ends output shape has 0-size dimension
(None, [4], [1], [1], [0], [1], None, [0]),
# since starts > ends output shape has 0-size dimension
(None, [4], [2], [1], [0], [1], None, [0])
])
def test_ov_slice_infer(self, inp_value, inp_shape, starts, ends, axes, steps, expected_value, expected_shape):
if inp_value is None:
input_node = shaped_data('data_1', int64_array(inp_shape))
else:
input_node = valued_data('data_1', int64_array(inp_value))
if inp_value is not None and inp_shape is not None:
assert np.array_equal(np.array(inp_value).shape, inp_shape)
def convert_args(val, name=''):
if val is not None:
return valued_const_with_data(name, int64_array(val))
else:
return shaped_const_with_data(name, [0]) # fake shape
starts = convert_args(starts, 'starts')
ends = convert_args(ends, 'ends')
steps = convert_args(steps, 'steps')
axes = convert_args(axes, 'axes')
if expected_shape is not None:
expected_shape = shape_array(expected_shape)
nodes = {
**input_node,
**regular_op_with_empty_data('slice', {'op': 'OvSlice'}),
**starts,
**ends,
**steps,
**axes,
}
graph = build_graph(nodes,
[('data_1', 'slice'),
*connect('starts', '1:slice'),
*connect('ends', '2:slice'),
*connect('steps', '3:slice'),
*connect('axes', '4:slice'),
*connect('slice', 'slice_d')])
graph.stage = 'middle'
slice_node = Node(graph, 'slice')
OvSlice.infer(slice_node)
if expected_value is not None:
self.assertTrue(strict_compare_tensors(slice_node.out_node().value, expected_value))
self.assertTrue(strict_compare_tensors(slice_node.out_node().shape, expected_shape))