[MO] set explicitly argument dtype to int for np.split (#9988)

* forced split argument dtype to int

* added unit-test

* fixed typo in split_test.py

* set explicitly np.int64 instead of np.int

* use split_length's dtype
This commit is contained in:
Pavel Esir 2022-02-09 12:16:33 +03:00 committed by GitHub
parent 25ca17e789
commit 654b025a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 3 deletions

View File

@ -6,7 +6,7 @@ import logging as log
import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, dynamic_dimension, shape_delete, \
clarify_partial_shape, shape_array
clarify_partial_shape, shape_array, mo_array
from openvino.tools.mo.graph.graph import Graph, Node
from openvino.tools.mo.graph.perm_inputs import PermuteInputs
from openvino.tools.mo.ops.op import Op, PermuteAttrs
@ -98,7 +98,7 @@ class VariadicSplitBase(Op):
# value propagation
input_value = node.in_port(0).data.get_value()
if input_value is not None:
split = np.split(input_value, idxs[:-1], axis)
split = np.split(input_value, mo_array(idxs[:-1], dtype=split_lengths.dtype), axis)
for i, port in node.out_ports().items():
if not port.disconnected():
port.data.set_value(split[i])

View File

@ -7,7 +7,7 @@ import numpy as np
from generator import generator, generate
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, shape_array, \
dynamic_dimension_value, dynamic_dimension, strict_compare_tensors
dynamic_dimension_value, dynamic_dimension, strict_compare_tensors, mo_array
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.split import AttributedSplit, AttributedVariadicSplit, VariadicSplit
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
@ -303,6 +303,36 @@ class TestVariadicSplitOp(unittest.TestCase):
for out in range(ont_nodes_count):
self.assertTrue(np.all(node.out_node(out).shape == int64_array([2, 12, lengths[out], 30])))
def test_variadic_split_value_inference_with_uint32(self):
axis = int64_array(2)
# because sum of Python int and Numpy np.uint64 gives float64
# but np.split accepts only integers and raises error for floats
# therefore needed to explicitly cast np.split arguments into integer
# added this test for that case
lengths = mo_array([2, 13, 10], dtype=np.uint64)
input_shape = mo_array([2, 12, 25, 30])
input_value = np.zeros(input_shape)
graph = build_graph(self.nodes, self.edges,
{
'split_input_data': {'shape': input_shape, 'value': input_value},
'split_axis_data': {'value': axis},
'split_lengths_data': {'value': lengths},
'split_op': {'out_ports_count': 4},
}
)
node = Node(graph, 'split_op')
for p in range(len(node.out_edges()), node.out_ports_count):
node.add_output_port(p)
VariadicSplit.infer(node)
ont_nodes_count = len(node.out_edges())
self.assertTrue(ont_nodes_count == 3)
for out in range(ont_nodes_count):
self.assertTrue(np.all(node.out_node(out).shape == int64_array([2, 12, lengths[out], 30])))
@generate(*[int64_array([[2], [2]]),
int64_array([2, 2])])
def test_negative_variadic_split_axis(self, axis):