[MO] set explicitly argument dtype to int for np.split (#9988)
* forced split argument dtype to int * added unit-test * fixed typo in split_test.py * set explicitly np.int64 instead of np.int * use split_length's dtype
This commit is contained in:
parent
25ca17e789
commit
654b025a26
@ -6,7 +6,7 @@ import logging as log
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import is_fully_defined, dynamic_dimension, shape_delete, \
|
||||
clarify_partial_shape, shape_array
|
||||
clarify_partial_shape, shape_array, mo_array
|
||||
from openvino.tools.mo.graph.graph import Graph, Node
|
||||
from openvino.tools.mo.graph.perm_inputs import PermuteInputs
|
||||
from openvino.tools.mo.ops.op import Op, PermuteAttrs
|
||||
@ -98,7 +98,7 @@ class VariadicSplitBase(Op):
|
||||
# value propagation
|
||||
input_value = node.in_port(0).data.get_value()
|
||||
if input_value is not None:
|
||||
split = np.split(input_value, idxs[:-1], axis)
|
||||
split = np.split(input_value, mo_array(idxs[:-1], dtype=split_lengths.dtype), axis)
|
||||
for i, port in node.out_ports().items():
|
||||
if not port.disconnected():
|
||||
port.data.set_value(split[i])
|
||||
|
@ -7,7 +7,7 @@ import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, shape_array, \
|
||||
dynamic_dimension_value, dynamic_dimension, strict_compare_tensors
|
||||
dynamic_dimension_value, dynamic_dimension, strict_compare_tensors, mo_array
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.split import AttributedSplit, AttributedVariadicSplit, VariadicSplit
|
||||
from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
@ -303,6 +303,36 @@ class TestVariadicSplitOp(unittest.TestCase):
|
||||
for out in range(ont_nodes_count):
|
||||
self.assertTrue(np.all(node.out_node(out).shape == int64_array([2, 12, lengths[out], 30])))
|
||||
|
||||
def test_variadic_split_value_inference_with_uint32(self):
|
||||
axis = int64_array(2)
|
||||
# because sum of Python int and Numpy np.uint64 gives float64
|
||||
|
||||
# but np.split accepts only integers and raises error for floats
|
||||
# therefore needed to explicitly cast np.split arguments into integer
|
||||
# added this test for that case
|
||||
lengths = mo_array([2, 13, 10], dtype=np.uint64)
|
||||
input_shape = mo_array([2, 12, 25, 30])
|
||||
input_value = np.zeros(input_shape)
|
||||
|
||||
graph = build_graph(self.nodes, self.edges,
|
||||
{
|
||||
'split_input_data': {'shape': input_shape, 'value': input_value},
|
||||
'split_axis_data': {'value': axis},
|
||||
'split_lengths_data': {'value': lengths},
|
||||
'split_op': {'out_ports_count': 4},
|
||||
}
|
||||
)
|
||||
node = Node(graph, 'split_op')
|
||||
for p in range(len(node.out_edges()), node.out_ports_count):
|
||||
node.add_output_port(p)
|
||||
|
||||
VariadicSplit.infer(node)
|
||||
|
||||
ont_nodes_count = len(node.out_edges())
|
||||
self.assertTrue(ont_nodes_count == 3)
|
||||
for out in range(ont_nodes_count):
|
||||
self.assertTrue(np.all(node.out_node(out).shape == int64_array([2, 12, lengths[out], 30])))
|
||||
|
||||
@generate(*[int64_array([[2], [2]]),
|
||||
int64_array([2, 2])])
|
||||
def test_negative_variadic_split_axis(self, axis):
|
||||
|
Loading…
Reference in New Issue
Block a user