[PYTHON] Add OV Types support to parameter and constant from opsets (#10489)
* Add OV Types to parameter and constant node factory, refactor tests and error handling * Fix name mismatch in docstring * Fix docs and hints
This commit is contained in:
parent
828d9d810a
commit
206442fb19
@ -1,16 +1,16 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""openvino exceptions hierarchy. All exceptions are descendants of NgraphError."""
|
||||
"""openvino exceptions hierarchy. All exceptions are descendants of OVError."""
|
||||
|
||||
|
||||
class NgraphError(Exception):
|
||||
"""Base class for Ngraph exceptions."""
|
||||
class OVError(Exception):
|
||||
"""Base class for OV exceptions."""
|
||||
|
||||
|
||||
class UserInputError(NgraphError):
|
||||
class UserInputError(OVError):
|
||||
"""User provided unexpected input."""
|
||||
|
||||
|
||||
class NgraphTypeError(NgraphError, TypeError):
|
||||
class OVTypeError(OVError, TypeError):
|
||||
"""Type mismatch error."""
|
||||
|
@ -2,12 +2,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""Factory functions for all openvino ops."""
|
||||
from typing import Callable, Iterable, List, Optional, Set, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from openvino.runtime import Node, PartialShape, Shape
|
||||
from openvino.runtime import Node, PartialShape, Type
|
||||
from openvino.runtime.op import Constant, Parameter, tensor_iterator
|
||||
from openvino.runtime.opset_utils import _get_node_factory
|
||||
from openvino.runtime.utils.decorators import binary_op, nameable_op, unary_op
|
||||
@ -312,7 +312,11 @@ def concat(nodes: List[NodeInput], axis: int, name: Optional[str] = None) -> Nod
|
||||
|
||||
|
||||
@nameable_op
|
||||
def constant(value: NumericData, dtype: NumericType = None, name: Optional[str] = None) -> Constant:
|
||||
def constant(
|
||||
value: NumericData,
|
||||
dtype: Union[NumericType, Type] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> Constant:
|
||||
"""Create a Constant node from provided value.
|
||||
|
||||
:param value: One of: array of values or scalar to initialize node with.
|
||||
@ -1544,7 +1548,6 @@ def matmul(
|
||||
:param transpose_b: should the second matrix be transposed
|
||||
returns MatMul operation node
|
||||
"""
|
||||
print("transpose_a", transpose_a, "transpose_b", transpose_b)
|
||||
return _get_node_factory_opset1().create(
|
||||
"MatMul", as_nodes(data_a, data_b), {"transpose_a": transpose_a, "transpose_b": transpose_b}
|
||||
)
|
||||
@ -1792,11 +1795,13 @@ def pad(
|
||||
|
||||
@nameable_op
|
||||
def parameter(
|
||||
shape: TensorShape, dtype: NumericType = np.float32, name: Optional[str] = None
|
||||
shape: TensorShape, dtype: Union[NumericType, Type] = np.float32, name: Optional[str] = None
|
||||
) -> Parameter:
|
||||
"""Return an openvino Parameter object."""
|
||||
element_type = get_element_type(dtype)
|
||||
return Parameter(element_type, PartialShape(shape))
|
||||
return Parameter(get_element_type(dtype)
|
||||
if isinstance(dtype, (type, np.dtype))
|
||||
else dtype,
|
||||
PartialShape(shape))
|
||||
|
||||
|
||||
@binary_op
|
||||
|
@ -8,9 +8,8 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.runtime.exceptions import NgraphTypeError
|
||||
from openvino.runtime import Node, Shape, Output
|
||||
from openvino.runtime import Type as NgraphType
|
||||
from openvino.runtime.exceptions import OVTypeError
|
||||
from openvino.runtime import Node, Shape, Output, Type
|
||||
from openvino.runtime.op import Constant
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -22,19 +21,19 @@ ScalarData = Union[int, float]
|
||||
NodeInput = Union[Node, NumericData]
|
||||
|
||||
openvino_to_numpy_types_map = [
|
||||
(NgraphType.boolean, np.bool),
|
||||
(NgraphType.f16, np.float16),
|
||||
(NgraphType.f32, np.float32),
|
||||
(NgraphType.f64, np.float64),
|
||||
(NgraphType.i8, np.int8),
|
||||
(NgraphType.i16, np.int16),
|
||||
(NgraphType.i32, np.int32),
|
||||
(NgraphType.i64, np.int64),
|
||||
(NgraphType.u8, np.uint8),
|
||||
(NgraphType.u16, np.uint16),
|
||||
(NgraphType.u32, np.uint32),
|
||||
(NgraphType.u64, np.uint64),
|
||||
(NgraphType.bf16, np.uint16),
|
||||
(Type.boolean, np.bool),
|
||||
(Type.f16, np.float16),
|
||||
(Type.f32, np.float32),
|
||||
(Type.f64, np.float64),
|
||||
(Type.i8, np.int8),
|
||||
(Type.i16, np.int16),
|
||||
(Type.i32, np.int32),
|
||||
(Type.i64, np.int64),
|
||||
(Type.u8, np.uint8),
|
||||
(Type.u16, np.uint16),
|
||||
(Type.u32, np.uint32),
|
||||
(Type.u64, np.uint64),
|
||||
(Type.bf16, np.uint16),
|
||||
]
|
||||
|
||||
openvino_to_numpy_types_str_map = [
|
||||
@ -53,23 +52,23 @@ openvino_to_numpy_types_str_map = [
|
||||
]
|
||||
|
||||
|
||||
def get_element_type(data_type: NumericType) -> NgraphType:
|
||||
def get_element_type(data_type: NumericType) -> Type:
|
||||
"""Return an ngraph element type for a Python type or numpy.dtype."""
|
||||
if data_type is int:
|
||||
log.warning("Converting int type of undefined bitwidth to 32-bit ngraph integer.")
|
||||
return NgraphType.i32
|
||||
return Type.i32
|
||||
|
||||
if data_type is float:
|
||||
log.warning("Converting float type of undefined bitwidth to 32-bit ngraph float.")
|
||||
return NgraphType.f32
|
||||
return Type.f32
|
||||
|
||||
ng_type = next(
|
||||
(ng_type for (ng_type, np_type) in openvino_to_numpy_types_map if np_type == data_type), None
|
||||
ov_type = next(
|
||||
(ov_type for (ov_type, np_type) in openvino_to_numpy_types_map if np_type == data_type), None
|
||||
)
|
||||
if ng_type:
|
||||
return ng_type
|
||||
if ov_type:
|
||||
return ov_type
|
||||
|
||||
raise NgraphTypeError("Unidentified data type %s", data_type)
|
||||
raise OVTypeError("Unidentified data type %s", data_type)
|
||||
|
||||
|
||||
def get_element_type_str(data_type: NumericType) -> str:
|
||||
@ -82,27 +81,27 @@ def get_element_type_str(data_type: NumericType) -> str:
|
||||
log.warning("Converting float type of undefined bitwidth to 32-bit ngraph float.")
|
||||
return "f32"
|
||||
|
||||
ng_type = next(
|
||||
(ng_type for (ng_type, np_type) in openvino_to_numpy_types_str_map if np_type == data_type),
|
||||
ov_type = next(
|
||||
(ov_type for (ov_type, np_type) in openvino_to_numpy_types_str_map if np_type == data_type),
|
||||
None,
|
||||
)
|
||||
if ng_type:
|
||||
return ng_type
|
||||
if ov_type:
|
||||
return ov_type
|
||||
|
||||
raise NgraphTypeError("Unidentified data type %s", data_type)
|
||||
raise OVTypeError("Unidentified data type %s", data_type)
|
||||
|
||||
|
||||
def get_dtype(ngraph_type: NgraphType) -> np.dtype:
|
||||
"""Return a numpy.dtype for an ngraph element type."""
|
||||
def get_dtype(openvino_type: Type) -> np.dtype:
|
||||
"""Return a numpy.dtype for an openvino element type."""
|
||||
np_type = next(
|
||||
(np_type for (ng_type, np_type) in openvino_to_numpy_types_map if ng_type == ngraph_type),
|
||||
(np_type for (ov_type, np_type) in openvino_to_numpy_types_map if ov_type == openvino_type),
|
||||
None,
|
||||
)
|
||||
|
||||
if np_type:
|
||||
return np.dtype(np_type)
|
||||
|
||||
raise NgraphTypeError("Unidentified data type %s", ngraph_type)
|
||||
raise OVTypeError("Unidentified data type %s", openvino_type)
|
||||
|
||||
|
||||
def get_ndarray(data: NumericData) -> np.ndarray:
|
||||
@ -121,11 +120,11 @@ def get_shape(data: NumericData) -> TensorShape:
|
||||
return []
|
||||
|
||||
|
||||
def make_constant_node(value: NumericData, dtype: NumericType = None) -> Constant:
|
||||
"""Return an ngraph Constant node with the specified value."""
|
||||
def make_constant_node(value: NumericData, dtype: Union[NumericType, Type] = None) -> Constant:
|
||||
"""Return an openvino Constant node with the specified value."""
|
||||
ndarray = get_ndarray(value)
|
||||
if dtype:
|
||||
element_type = get_element_type(dtype)
|
||||
if dtype is not None:
|
||||
element_type = get_element_type(dtype) if isinstance(dtype, (type, np.dtype)) else dtype
|
||||
else:
|
||||
element_type = get_element_type(ndarray.dtype)
|
||||
|
||||
|
@ -18,17 +18,19 @@ from openvino.runtime import Tensor
|
||||
from openvino.pyopenvino import DescriptorTensor
|
||||
from openvino.runtime.op import Parameter
|
||||
from tests.runtime import get_runtime
|
||||
from openvino.runtime.utils.types import get_dtype
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
|
||||
|
||||
def test_ngraph_function_api():
|
||||
shape = [2, 2]
|
||||
parameter_a = ops.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ops.parameter(shape, dtype=np.float32, name="B")
|
||||
parameter_b = ops.parameter(shape, dtype=Type.f32, name="B")
|
||||
parameter_c = ops.parameter(shape, dtype=np.float32, name="C")
|
||||
model = (parameter_a + parameter_b) * parameter_c
|
||||
|
||||
assert parameter_a.element_type == Type.f32
|
||||
assert parameter_b.element_type == Type.f32
|
||||
assert parameter_a.partial_shape == PartialShape([2, 2])
|
||||
parameter_a.layout = ov.Layout("NC")
|
||||
assert parameter_a.layout == ov.Layout("NC")
|
||||
@ -74,6 +76,17 @@ def test_ngraph_function_api():
|
||||
np.uint16,
|
||||
np.uint32,
|
||||
np.uint64,
|
||||
Type.f16,
|
||||
Type.f32,
|
||||
Type.f64,
|
||||
Type.i8,
|
||||
Type.i16,
|
||||
Type.i32,
|
||||
Type.i64,
|
||||
Type.u8,
|
||||
Type.u16,
|
||||
Type.u32,
|
||||
Type.u64,
|
||||
],
|
||||
)
|
||||
def test_simple_computation_on_ndarrays(dtype):
|
||||
@ -86,17 +99,19 @@ def test_simple_computation_on_ndarrays(dtype):
|
||||
model = (parameter_a + parameter_b) * parameter_c
|
||||
computation = runtime.computation(model, parameter_a, parameter_b, parameter_c)
|
||||
|
||||
value_a = np.array([[1, 2], [3, 4]], dtype=dtype)
|
||||
value_b = np.array([[5, 6], [7, 8]], dtype=dtype)
|
||||
value_c = np.array([[2, 3], [4, 5]], dtype=dtype)
|
||||
result = computation(value_a, value_b, value_c)
|
||||
assert np.allclose(result, np.array([[12, 24], [40, 60]], dtype=dtype))
|
||||
np_dtype = get_dtype(dtype) if isinstance(dtype, Type) else dtype
|
||||
|
||||
value_a = np.array([[9, 10], [11, 12]], dtype=dtype)
|
||||
value_b = np.array([[13, 14], [15, 16]], dtype=dtype)
|
||||
value_c = np.array([[5, 4], [3, 2]], dtype=dtype)
|
||||
value_a = np.array([[1, 2], [3, 4]], dtype=np_dtype)
|
||||
value_b = np.array([[5, 6], [7, 8]], dtype=np_dtype)
|
||||
value_c = np.array([[2, 3], [4, 5]], dtype=np_dtype)
|
||||
result = computation(value_a, value_b, value_c)
|
||||
assert np.allclose(result, np.array([[110, 96], [78, 56]], dtype=dtype))
|
||||
assert np.allclose(result, np.array([[12, 24], [40, 60]], dtype=np_dtype))
|
||||
|
||||
value_a = np.array([[9, 10], [11, 12]], dtype=np_dtype)
|
||||
value_b = np.array([[13, 14], [15, 16]], dtype=np_dtype)
|
||||
value_c = np.array([[5, 4], [3, 2]], dtype=np_dtype)
|
||||
result = computation(value_a, value_b, value_c)
|
||||
assert np.allclose(result, np.array([[110, 96], [78, 56]], dtype=np_dtype))
|
||||
|
||||
|
||||
def test_serialization():
|
||||
|
@ -34,7 +34,7 @@ def einsum_op_exec(input_shapes: list, equation: str, data_type: np.dtype,
|
||||
ng_inputs = []
|
||||
np_inputs = []
|
||||
for i in range(num_inputs):
|
||||
input_i = np.random.random_integers(10, size=input_shapes[i]).astype(data_type)
|
||||
input_i = np.random.randint(1, 10 + 1, size=input_shapes[i]).astype(data_type)
|
||||
np_inputs.append(input_i)
|
||||
ng_inputs.append(ov.parameter(input_i.shape, dtype=data_type))
|
||||
|
||||
|
@ -482,6 +482,30 @@ def test_constant():
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
def test_constant_opset_ov_type():
|
||||
parameter_list = []
|
||||
function = Model([ov.constant(np.arange(9).reshape(3, 3), Type.f32)], parameter_list, "test")
|
||||
|
||||
runtime = get_runtime()
|
||||
computation = runtime.computation(function, *parameter_list)
|
||||
result = computation()[0]
|
||||
|
||||
expected = np.arange(9).reshape(3, 3)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
def test_constant_opset_numpy_type():
|
||||
parameter_list = []
|
||||
function = Model([ov.constant(np.arange(9).reshape(3, 3), np.float32)], parameter_list, "test")
|
||||
|
||||
runtime = get_runtime()
|
||||
computation = runtime.computation(function, *parameter_list)
|
||||
result = computation()[0]
|
||||
|
||||
expected = np.arange(9).reshape(3, 3)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
def test_concat():
|
||||
|
||||
element_type = Type.f32
|
||||
|
@ -48,7 +48,7 @@ def test_fake_quantize():
|
||||
input_high_value = np.float32(23)
|
||||
output_low_value = np.float32(2)
|
||||
output_high_value = np.float32(16)
|
||||
levels = np.float32(4)
|
||||
levels = np.int32(4)
|
||||
|
||||
data_shape = [1, 2, 3, 4]
|
||||
bound_shape = []
|
||||
@ -114,7 +114,7 @@ def test_depth_to_space():
|
||||
dtype=np.float32,
|
||||
)
|
||||
mode = "blocks_first"
|
||||
block_size = np.float32(2)
|
||||
block_size = np.int32(2)
|
||||
|
||||
data_shape = [1, 4, 2, 3]
|
||||
parameter_data = ov.parameter(data_shape, name="Data", dtype=np.float32)
|
||||
|
@ -7,7 +7,7 @@ import onnx.mapping
|
||||
import pytest
|
||||
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
|
||||
|
||||
from openvino.runtime.exceptions import NgraphTypeError
|
||||
from openvino.runtime.exceptions import OVTypeError
|
||||
from tests.runtime import get_runtime
|
||||
from tests.test_onnx.utils import get_node_model, import_onnx_model, run_model, run_node
|
||||
|
||||
@ -425,7 +425,7 @@ def test_cast_errors():
|
||||
|
||||
graph = make_graph([node], "compute_graph", input_tensors, output_tensors)
|
||||
model = make_model(graph, producer_name="NgraphBackend")
|
||||
with pytest.raises((RuntimeError, NgraphTypeError)):
|
||||
with pytest.raises((RuntimeError, OVTypeError)):
|
||||
import_onnx_model(model)
|
||||
|
||||
# unsupported output tensor data type:
|
||||
|
Loading…
Reference in New Issue
Block a user