[PyOV] Expose result op (#21482)

* [PyOV] Expose result op

* update docs with op.Result

* codestyle

* add test and fix flake8 errors

* add result test

* fix transformation tests

* update return type
This commit is contained in:
Anastasia Kuporosova 2023-12-08 14:57:11 +01:00 committed by GitHub
parent a3de4c17a1
commit ba34fa77e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 112 additions and 79 deletions

View File

@ -8,11 +8,10 @@ Low level wrappers for the c++ api in ov::op.
# flake8: noqa # flake8: noqa
import numpy as np
from openvino._pyopenvino.op import Constant from openvino._pyopenvino.op import Constant
from openvino._pyopenvino.op import assign from openvino._pyopenvino.op import assign
from openvino._pyopenvino.op import Parameter from openvino._pyopenvino.op import Parameter
from openvino._pyopenvino.op import if_op from openvino._pyopenvino.op import if_op
from openvino._pyopenvino.op import loop from openvino._pyopenvino.op import loop
from openvino._pyopenvino.op import tensor_iterator from openvino._pyopenvino.op import tensor_iterator
from openvino._pyopenvino.op import Result

View File

@ -144,7 +144,7 @@ from openvino.runtime.opset1.ops import region_yolo
from openvino.runtime.opset2.ops import reorg_yolo from openvino.runtime.opset2.ops import reorg_yolo
from openvino.runtime.opset1.ops import relu from openvino.runtime.opset1.ops import relu
from openvino.runtime.opset1.ops import reshape from openvino.runtime.opset1.ops import reshape
from openvino.runtime.opset1.ops import result from openvino.runtime.opset13.ops import result
from openvino.runtime.opset1.ops import reverse_sequence from openvino.runtime.opset1.ops import reverse_sequence
from openvino.runtime.opset3.ops import rnn_cell from openvino.runtime.opset3.ops import rnn_cell
from openvino.runtime.opset5.ops import rnn_sequence from openvino.runtime.opset5.ops import rnn_sequence

View File

@ -11,8 +11,8 @@ import numpy as np
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from openvino.runtime import Node, Shape, Type from openvino.runtime import Node, Shape, Type, Output
from openvino.runtime.op import Constant from openvino.runtime.op import Constant, Result
from openvino.runtime.opset_utils import _get_node_factory from openvino.runtime.opset_utils import _get_node_factory
from openvino.runtime.utils.decorators import binary_op, nameable_op, unary_op from openvino.runtime.utils.decorators import binary_op, nameable_op, unary_op
from openvino.runtime.utils.types import ( from openvino.runtime.utils.types import (
@ -334,3 +334,15 @@ def constant(
_value, _shared_memory = _value.astype(_dtype), False _value, _shared_memory = _value.astype(_dtype), False
# Create Constant itself: # Create Constant itself:
return Constant(_value, shared_memory=_shared_memory) return Constant(_value, shared_memory=_shared_memory)
@unary_op
def result(data: Union[Node, Output, NumericData], name: Optional[str] = None) -> Node:
"""Return a node which represents an output of a graph (Model).
:param data: The tensor containing the input data
:return: Result node
"""
if isinstance(data, Node):
return Result(data.output(0))
return Result(data)

View File

@ -111,7 +111,7 @@ void regclass_graph_Model(py::module m) {
Create user-defined Model which is a representation of a model. Create user-defined Model which is a representation of a model.
:param results: List of results. :param results: List of results.
:type results: List[openvino.runtime.Node] :type results: List[op.Result]
:param sinks: List of Nodes to be used as Sinks (e.g. Assign ops). :param sinks: List of Nodes to be used as Sinks (e.g. Assign ops).
:type sinks: List[openvino.runtime.Node] :type sinks: List[openvino.runtime.Node]
:param parameters: List of parameters. :param parameters: List of parameters.
@ -221,7 +221,7 @@ void regclass_graph_Model(py::module m) {
Create user-defined Model which is a representation of a model Create user-defined Model which is a representation of a model
:param results: List of results. :param results: List of results.
:type results: List[openvino.runtime.Node] :type results: List[op.Result]
:param sinks: List of Nodes to be used as Sinks (e.g. Assign ops). :param sinks: List of Nodes to be used as Sinks (e.g. Assign ops).
:type sinks: List[openvino.runtime.Node] :type sinks: List[openvino.runtime.Node]
:param parameters: List of parameters. :param parameters: List of parameters.
@ -274,7 +274,7 @@ void regclass_graph_Model(py::module m) {
Create user-defined Model which is a representation of a model Create user-defined Model which is a representation of a model
:param results: List of results. :param results: List of results.
:type results: List[openvino.runtime.Node] :type results: List[op.Result]
:param parameters: List of parameters. :param parameters: List of parameters.
:type parameters: List[op.Parameter] :type parameters: List[op.Parameter]
:param variables: List of variables. :param variables: List of variables.
@ -538,7 +538,7 @@ void regclass_graph_Model(py::module m) {
Return a list of model outputs. Return a list of model outputs.
:return: a list of model's result nodes. :return: a list of model's result nodes.
:rtype: List[openvino.runtime.Node] :rtype: List[op.Result]
)"); )");
model.def_property_readonly("results", model.def_property_readonly("results",
&ov::Model::get_results, &ov::Model::get_results,
@ -546,7 +546,7 @@ void regclass_graph_Model(py::module m) {
Return a list of model outputs. Return a list of model outputs.
:return: a list of model's result nodes. :return: a list of model's result nodes.
:rtype: List[openvino.runtime.Node] :rtype: List[op.Result]
)"); )");
model.def("get_result", model.def("get_result",
&ov::Model::get_result, &ov::Model::get_result,
@ -554,7 +554,7 @@ void regclass_graph_Model(py::module m) {
Return single result. Return single result.
:return: Node object representing result. :return: Node object representing result.
:rtype: openvino.runtime.Node :rtype: op.Result
)"); )");
model.def_property_readonly("result", model.def_property_readonly("result",
&ov::Model::get_result, &ov::Model::get_result,
@ -562,7 +562,7 @@ void regclass_graph_Model(py::module m) {
Return single result. Return single result.
:return: Node object representing result. :return: Node object representing result.
:rtype: openvino.runtime.Node :rtype: op.Result
)"); )");
model.def("get_result_index", model.def("get_result_index",
(int64_t(ov::Model::*)(const ov::Output<ov::Node>&) const) & ov::Model::get_result_index, (int64_t(ov::Model::*)(const ov::Output<ov::Node>&) const) & ov::Model::get_result_index,
@ -747,7 +747,7 @@ void regclass_graph_Model(py::module m) {
Delete Result node from the list of results. Method will not delete node from graph. Delete Result node from the list of results. Method will not delete node from graph.
:param result: Result node to delete. :param result: Result node to delete.
:type result: openvino.runtime.Node :type result: op.Result
)"); )");
model.def("remove_parameter", model.def("remove_parameter",
@ -827,7 +827,7 @@ void regclass_graph_Model(py::module m) {
Method doesn't validate graph, it should be done manually after all changes. Method doesn't validate graph, it should be done manually after all changes.
:param results: new Result nodes. :param results: new Result nodes.
:type results: List[openvino.runtime.Node] :type results: List[op.Result]
)"); )");
model.def( model.def(

View File

@ -115,10 +115,10 @@ void regclass_graph_op_If(py::module m) {
Sets new output from the operation associated with results of each sub-graphs. Sets new output from the operation associated with results of each sub-graphs.
:param then_result: result from then_body. :param then_result: result from then_body.
:type then_result: openvino.runtime.Node :type then_result: op.Result
:param else_result: result from else_body. :param else_result: result from else_body.
:type else_result: openvino.runtime.Node :type else_result: op.Result
:return: output from operation. :return: output from operation.
:rtype: openvino.runtime.Output :rtype: openvino.runtime.Output

View File

@ -20,6 +20,8 @@ void regclass_graph_op_Result(py::module m) {
result.doc() = "openvino.runtime.op.Result wraps ov::op::v0::Result"; result.doc() = "openvino.runtime.op.Result wraps ov::op::v0::Result";
result.def(py::init<const ov::Output<ov::Node>&>());
result.def("get_output_partial_shape", &ov::Node::get_output_partial_shape, py::arg("index")); result.def("get_output_partial_shape", &ov::Node::get_output_partial_shape, py::arg("index"));
result.def("get_output_element_type", &ov::Node::get_output_element_type, py::arg("index")); result.def("get_output_element_type", &ov::Node::get_output_element_type, py::arg("index"));

View File

@ -25,6 +25,7 @@ from openvino.runtime import AxisVector, Coordinate, CoordinateDiff
from openvino._pyopenvino import DescriptorTensor from openvino._pyopenvino import DescriptorTensor
from openvino.runtime.utils.types import get_element_type from openvino.runtime.utils.types import get_element_type
from tests.utils.helpers import generate_model_with_memory
def test_graph_api(): def test_graph_api():
@ -554,12 +555,7 @@ def test_multiple_outputs():
def test_sink_model_ctor(): def test_sink_model_ctor():
input_data = ops.parameter([2, 2], name="input_data", dtype=np.float32) model = generate_model_with_memory(input_shape=[2, 2], data_type=np.float32)
rv = ops.read_value(input_data, "var_id_667", np.float32, [2, 2])
add = ops.add(rv, input_data, name="MemoryAdd")
node = ops.assign(add, "var_id_667")
res = ops.result(add, "res")
model = Model(results=[res], sinks=[node], parameters=[input_data], name="TestModel")
ordered_ops = model.get_ordered_ops() ordered_ops = model.get_ordered_ops()
op_types = [op.get_type_name() for op in ordered_ops] op_types = [op.get_type_name() for op in ordered_ops]
@ -570,7 +566,7 @@ def test_sink_model_ctor():
assert len(model.get_ops()) == 5 assert len(model.get_ops()) == 5
assert model.get_output_size() == 1 assert model.get_output_size() == 1
assert model.get_output_op(0).get_type_name() == "Result" assert model.get_output_op(0).get_type_name() == "Result"
assert model.get_output_element_type(0) == input_data.get_element_type() assert model.get_output_element_type(0) == model.get_parameters()[0].get_element_type()
assert list(model.get_output_shape(0)) == [2, 2] assert list(model.get_output_shape(0)) == [2, 2]
assert (model.get_parameters()[0].get_partial_shape()) == PartialShape([2, 2]) assert (model.get_parameters()[0].get_partial_shape()) == PartialShape([2, 2])
assert len(model.get_parameters()) == 1 assert len(model.get_parameters()) == 1

View File

@ -1,8 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation # Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# flake8: noqa
import os import os
import numpy as np import numpy as np
@ -15,6 +14,7 @@ from openvino.runtime.passes import Manager, Serialize, ConstantFolding, Version
from tests.test_graph.util import count_ops_of_type from tests.test_graph.util import count_ops_of_type
from tests.utils.helpers import create_filename_for_test, compare_models from tests.utils.helpers import create_filename_for_test, compare_models
def create_model(): def create_model():
shape = [100, 100, 2] shape = [100, 100, 2]
parameter_a = ops.parameter(shape, dtype=np.float32, name="A") parameter_a = ops.parameter(shape, dtype=np.float32, name="A")
@ -40,7 +40,8 @@ def test_constant_folding():
assert count_ops_of_type(model, node_ceil) == 0 assert count_ops_of_type(model, node_ceil) == 0
assert count_ops_of_type(model, node_constant) == 1 assert count_ops_of_type(model, node_constant) == 1
new_const = model.get_results()[0].input(0).get_source_output().get_node() result = model.get_results()[0]
new_const = result.input(0).get_source_output().get_node()
values_out = new_const.get_vector() values_out = new_const.get_vector()
values_expected = [0.0, 1.0, 0.0, -2.0, 3.0, 3.0] values_expected = [0.0, 1.0, 0.0, -2.0, 3.0, 3.0]
@ -48,14 +49,14 @@ def test_constant_folding():
# request - https://docs.pytest.org/en/7.1.x/reference/reference.html#request # request - https://docs.pytest.org/en/7.1.x/reference/reference.html#request
@pytest.fixture @pytest.fixture()
def prepare_ir_paths(request, tmp_path): def prepare_ir_paths(request, tmp_path):
xml_path, bin_path = create_filename_for_test(request.node.name, tmp_path) xml_path, bin_path = create_filename_for_test(request.node.name, tmp_path)
yield xml_path, bin_path yield xml_path, bin_path
# IR Files deletion should be done after `Model` is destructed. # IR Files deletion should be done after `Model` is destructed.
# It may be achieved by splitting scopes (`Model` will be destructed # It may be achieved by splitting scopes (`Model` will be destructed
# just after test scope finished), or by calling `del Model` # just after test scope finished), or by calling `del Model`
os.remove(xml_path) os.remove(xml_path)
os.remove(bin_path) os.remove(bin_path)
@ -104,7 +105,7 @@ def test_serialize_separate_paths_args(prepare_ir_paths):
def test_serialize_pass_mixed_args_kwargs(prepare_ir_paths): def test_serialize_pass_mixed_args_kwargs(prepare_ir_paths):
core = Core() core = Core()
shape = [3, 2] shape = [3, 2]
parameter_a = ops.parameter(shape, dtype=np.float32, name="A") parameter_a = ops.parameter(shape, dtype=np.float32, name="A")
parameter_b = ops.parameter(shape, dtype=np.float32, name="B") parameter_b = ops.parameter(shape, dtype=np.float32, name="B")
@ -123,7 +124,7 @@ def test_serialize_pass_mixed_args_kwargs(prepare_ir_paths):
def test_serialize_pass_mixed_args_kwargs_v2(prepare_ir_paths): def test_serialize_pass_mixed_args_kwargs_v2(prepare_ir_paths):
core = Core() core = Core()
xml_path, bin_path = prepare_ir_paths xml_path, bin_path = prepare_ir_paths
model = create_model() model = create_model()
pass_manager = Manager() pass_manager = Manager()
@ -175,7 +176,7 @@ def test_default_version(prepare_ir_paths):
assert compare_models(model, res_model) assert compare_models(model, res_model)
def test_default_version_IR_V11_separate_paths(prepare_ir_paths): def test_default_version_ir_v11_separate_paths(prepare_ir_paths):
core = Core() core = Core()
xml_path, bin_path = prepare_ir_paths xml_path, bin_path = prepare_ir_paths

View File

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation # Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
@ -12,6 +13,7 @@ from openvino import Shape, Type
from openvino.runtime import AxisSet from openvino.runtime import AxisSet
from openvino.runtime.op import Constant, Parameter from openvino.runtime.op import Constant, Parameter
@pytest.mark.parametrize(("ov_op", "expected_ov_str", "expected_type"), [ @pytest.mark.parametrize(("ov_op", "expected_ov_str", "expected_type"), [
(lambda a, b: a + b, "Add", Type.f32), (lambda a, b: a + b, "Add", Type.f32),
(ov.add, "Add", Type.f32), (ov.add, "Add", Type.f32),
@ -34,9 +36,9 @@ from openvino.runtime.op import Constant, Parameter
def test_binary_op(ov_op, expected_ov_str, expected_type): def test_binary_op(ov_op, expected_ov_str, expected_type):
element_type = Type.f32 element_type = Type.f32
shape = Shape([2, 2]) shape = Shape([2, 2])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
B = Parameter(element_type, shape) param2 = Parameter(element_type, shape)
node = ov_op(A, B) node = ov_op(param1, param2)
assert node.get_type_name() == expected_ov_str assert node.get_type_name() == expected_ov_str
assert node.get_output_size() == 1 assert node.get_output_size() == 1
@ -48,10 +50,10 @@ def test_add_with_mul():
element_type = Type.f32 element_type = Type.f32
shape = Shape([4]) shape = Shape([4])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
B = Parameter(element_type, shape) param2 = Parameter(element_type, shape)
C = Parameter(element_type, shape) param3 = Parameter(element_type, shape)
node = ov.multiply(ov.add(A, B), C) node = ov.multiply(ov.add(param1, param2), param3)
assert node.get_type_name() == "Multiply" assert node.get_type_name() == "Multiply"
assert node.get_output_size() == 1 assert node.get_output_size() == 1
@ -85,8 +87,8 @@ def test_unary_op(ov_op, expected_ov_str):
element_type = Type.f32 element_type = Type.f32
shape = Shape([4]) shape = Shape([4])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
node = ov_op(A) node = ov_op(param1)
assert node.get_type_name() == expected_ov_str assert node.get_type_name() == expected_ov_str
assert node.get_output_size() == 1 assert node.get_output_size() == 1
@ -97,8 +99,8 @@ def test_unary_op(ov_op, expected_ov_str):
def test_reshape(): def test_reshape():
element_type = Type.f32 element_type = Type.f32
shape = Shape([2, 3]) shape = Shape([2, 3])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
node = ov.reshape(A, Shape([3, 2]), special_zero=False) node = ov.reshape(param1, Shape([3, 2]), special_zero=False)
assert node.get_type_name() == "Reshape" assert node.get_type_name() == "Reshape"
assert node.get_output_size() == 1 assert node.get_output_size() == 1
@ -108,8 +110,8 @@ def test_reshape():
def test_broadcast(): def test_broadcast():
element_type = Type.f32 element_type = Type.f32
A = Parameter(element_type, Shape([3])) param1 = Parameter(element_type, Shape([3]))
node = ov.broadcast(A, [3, 3]) node = ov.broadcast(param1, [3, 3])
assert node.get_type_name() == "Broadcast" assert node.get_type_name() == "Broadcast"
assert node.get_output_size() == 1 assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [3, 3] assert list(node.get_output_shape(0)) == [3, 3]
@ -134,10 +136,10 @@ def test_constant(const, args, expectation):
def test_concat(): def test_concat():
element_type = Type.f32 element_type = Type.f32
A = Parameter(element_type, Shape([1, 2])) param1 = Parameter(element_type, Shape([1, 2]))
B = Parameter(element_type, Shape([1, 2])) param2 = Parameter(element_type, Shape([1, 2]))
C = Parameter(element_type, Shape([1, 2])) param3 = Parameter(element_type, Shape([1, 2]))
node = ov.concat([A, B, C], axis=0) node = ov.concat([param1, param2, param3], axis=0)
assert node.get_type_name() == "Concat" assert node.get_type_name() == "Concat"
assert node.get_output_size() == 1 assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [3, 2] assert list(node.get_output_shape(0)) == [3, 2]
@ -162,10 +164,10 @@ def test_axisset():
def test_select(): def test_select():
element_type = Type.f32 element_type = Type.f32
A = Parameter(Type.boolean, Shape([1, 2])) param1 = Parameter(Type.boolean, Shape([1, 2]))
B = Parameter(element_type, Shape([1, 2])) param2 = Parameter(element_type, Shape([1, 2]))
C = Parameter(element_type, Shape([1, 2])) param3 = Parameter(element_type, Shape([1, 2]))
node = ov.select(A, B, C) node = ov.select(param1, param2, param3)
assert node.get_type_name() == "Select" assert node.get_type_name() == "Select"
assert node.get_output_size() == 1 assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [1, 2] assert list(node.get_output_shape(0)) == [1, 2]
@ -175,7 +177,7 @@ def test_select():
def test_max_pool_1d(): def test_max_pool_1d():
element_type = Type.f32 element_type = Type.f32
shape = Shape([1, 1, 10]) shape = Shape([1, 1, 10])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
window_shape = [3] window_shape = [3]
strides = [1] * len(window_shape) strides = [1] * len(window_shape)
@ -187,7 +189,7 @@ def test_max_pool_1d():
idx_elem_type = "i32" idx_elem_type = "i32"
model = ov.max_pool( model = ov.max_pool(
A, param1,
strides, strides,
dilations, dilations,
pads_begin, pads_begin,
@ -204,10 +206,11 @@ def test_max_pool_1d():
assert model.get_output_element_type(0) == element_type assert model.get_output_element_type(0) == element_type
assert model.get_output_element_type(1) == Type.i32 assert model.get_output_element_type(1) == Type.i32
def test_max_pool_1d_with_strides(): def test_max_pool_1d_with_strides():
element_type = Type.f32 element_type = Type.f32
shape = Shape([1, 1, 10]) shape = Shape([1, 1, 10])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
window_shape = [3] window_shape = [3]
strides = [2] strides = [2]
pads_begin = [0] * len(window_shape) pads_begin = [0] * len(window_shape)
@ -218,7 +221,7 @@ def test_max_pool_1d_with_strides():
idx_elem_type = "i32" idx_elem_type = "i32"
model = ov.max_pool( model = ov.max_pool(
A, param1,
strides, strides,
dilations, dilations,
pads_begin, pads_begin,
@ -236,10 +239,11 @@ def test_max_pool_1d_with_strides():
assert model.get_output_element_type(0) == element_type assert model.get_output_element_type(0) == element_type
assert model.get_output_element_type(1) == Type.i32 assert model.get_output_element_type(1) == Type.i32
def test_max_pool_2d(): def test_max_pool_2d():
element_type = Type.f32 element_type = Type.f32
shape = Shape([1, 1, 10, 10]) shape = Shape([1, 1, 10, 10])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
window_shape = [3, 3] window_shape = [3, 3]
rounding_type = "floor" rounding_type = "floor"
auto_pad = "explicit" auto_pad = "explicit"
@ -251,7 +255,7 @@ def test_max_pool_2d():
pads_end = [0, 0] pads_end = [0, 0]
model = ov.max_pool( model = ov.max_pool(
A, param1,
strides, strides,
dilations, dilations,
pads_begin, pads_begin,
@ -272,7 +276,7 @@ def test_max_pool_2d():
def test_max_pool_2d_with_strides(): def test_max_pool_2d_with_strides():
element_type = Type.f32 element_type = Type.f32
shape = Shape([1, 1, 10, 10]) shape = Shape([1, 1, 10, 10])
A = Parameter(element_type, shape) param1 = Parameter(element_type, shape)
strides = [2, 2] strides = [2, 2]
dilations = [1, 1] dilations = [1, 1]
pads_begin = [0, 0] pads_begin = [0, 0]
@ -283,7 +287,7 @@ def test_max_pool_2d_with_strides():
idx_elem_type = "i32" idx_elem_type = "i32"
model = ov.max_pool( model = ov.max_pool(
A, param1,
strides, strides,
dilations, dilations,
pads_begin, pads_begin,

View File

@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino import PartialShape, Model, Type
import openvino.runtime.opset13 as ops
from openvino.runtime.op import Result
def test_result():
param = ops.parameter(PartialShape([1]), dtype=np.float32, name="param")
relu1 = ops.relu(param, name="relu1")
result = Result(relu1.output(0))
assert result.get_output_element_type(0) == Type.f32
assert result.get_output_partial_shape(0) == PartialShape([1])
model = Model([result], [param], "test_model")
result2 = ops.result(relu1, "res2")
model.add_results([result2])
results = model.get_results()
assert len(results) == 2
assert results[1].get_output_element_type(0) == Type.f32
assert results[1].get_output_partial_shape(0) == PartialShape([1])
model.remove_result(result)
assert len(model.results) == 1

View File

@ -1,24 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation # Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# flake8: noqa
import numpy as np import numpy as np
import pytest import pytest
import openvino.runtime.opset13 as ov import openvino.runtime.opset13 as ov
from openvino import Type from openvino import Type
@pytest.mark.parametrize( @pytest.mark.parametrize("pad_mode", [
"pad_mode", "constant", "edge", "reflect", "symmetric",
[ ])
"constant",
"edge",
"reflect",
"symmetric",
]
)
def test_pad_mode(pad_mode): def test_pad_mode(pad_mode):
pads_begin = np.array([0, 1], dtype=np.int32) pads_begin = np.array([0, 1], dtype=np.int32)
pads_end = np.array([2, 3], dtype=np.int32) pads_end = np.array([2, 3], dtype=np.int32)
@ -32,13 +26,10 @@ def test_pad_mode(pad_mode):
assert model.get_output_element_type(0) == Type.i32 assert model.get_output_element_type(0) == Type.i32
@pytest.mark.parametrize( @pytest.mark.parametrize(("pads_begin", "pads_end", "output_shape"), [
("pads_begin", "pads_end", "output_shape"), ([-1, -1], [-1, -1], [1, 2]),
[ ([2, -1], [-1, 3], [4, 6]),
([-1, -1], [-1, -1], [1, 2]), ])
([2, -1], [-1, 3], [4, 6]),
]
)
def test_pad_being_and_end(pads_begin, pads_end, output_shape): def test_pad_being_and_end(pads_begin, pads_end, output_shape):
input_param = ov.parameter((3, 4), name="input", dtype=np.int32) input_param = ov.parameter((3, 4), name="input", dtype=np.int32)
model = ov.pad(input_param, pads_begin, pads_end, "constant") model = ov.pad(input_param, pads_begin, pads_end, "constant")