[PyOV] fix if_op return types (#20014)

* [PyOV] fix if_op return types

* fix if

* other types

* test for if

* add tests, try fix loop, ti

* fix ci
This commit is contained in:
Anastasia Kuporosova 2023-09-28 09:35:54 +02:00 committed by GitHub
parent ca7031f69d
commit c90bcbf5eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 57 additions and 13 deletions

View File

@ -998,6 +998,18 @@ void regclass_graph_Model(py::module m) {
:type path: str
)");
model.def(
"_get_raw_address",
[](ov::Model& self) {
return reinterpret_cast<uint64_t>(&self);
},
R"(
Returns raw address of the Model object.
:return: a raw address of the Model object.
:rtype: int
)");
model.def_property_readonly("inputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::inputs);
model.def_property_readonly("outputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::outputs);
model.def_property_readonly("name", &ov::Model::get_name);

View File

@ -49,9 +49,14 @@ void regclass_graph_op_If(py::module m) {
:rtype: openvino.impl.op.If
)");
cls.def("get_else_body",
&ov::op::v8::If::get_else_body,
R"(
cls.def(
"get_else_body",
[](ov::op::v8::If& self) {
auto model = self.get_else_body();
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
return model_class(py::cast(model));
},
R"(
Gets else_body as Model object.
:return: else_body as Model object.
@ -119,10 +124,15 @@ void regclass_graph_op_If(py::module m) {
:rtype: openvino.runtime.Output
)");
cls.def("get_function",
&ov::op::util::MultiSubGraphOp::get_function,
py::arg("index"),
R"(
cls.def(
"get_function",
[](ov::op::v8::If& self, size_t index) {
auto model = self.get_function(index);
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
return model_class(py::cast(model));
},
py::arg("index"),
R"(
Gets internal sub-graph by index in MultiSubGraphOp.
:param index: sub-graph's index in op.

View File

@ -92,7 +92,9 @@ void regclass_graph_op_Loop(py::module m) {
py::arg("successive_value"));
cls.def("get_function", [](const std::shared_ptr<ov::op::v5::Loop>& self) {
return self->get_function();
auto model = self->get_function();
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
return model_class(py::cast(model));
});
cls.def(

View File

@ -55,11 +55,15 @@ void regclass_graph_op_TensorIterator(py::module m) {
py::arg("successive_value"));
cls.def("get_body", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
return self->get_body();
auto model = self->get_body();
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
return model_class(py::cast(model));
});
cls.def("get_function", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
return self->get_function();
auto model = self->get_function();
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
return model_class(py::cast(model));
});
cls.def(

View File

@ -8,6 +8,8 @@ from openvino.runtime import Model
from openvino.runtime.op.util import InvariantInputDescription, BodyOutputDescription
from tests.utils.helpers import compare_models
def create_simple_if_with_two_outputs(condition_val):
condition = ov.constant(condition_val, dtype=bool)
@ -191,7 +193,11 @@ def test_simple_if_basic():
if_node = ov.if_op(condition.output(0))
if_node.set_function(0, then_body)
assert if_node.get_function(0) == then_body
subgraph_func = if_node.get_function(0)
assert type(subgraph_func) == type(then_body)
assert compare_models(subgraph_func, then_body)
assert subgraph_func._get_raw_address() == then_body._get_raw_address()
if_node.set_input_descriptions(0, then_body_inputs)
if_node.set_output_descriptions(1, else_body_outputs)

View File

@ -13,6 +13,7 @@ from openvino.runtime.op.util import (
MergedInputDescription,
ConcatOutputDescription,
)
from tests.utils.helpers import compare_models
def test_simple_loop():
@ -139,7 +140,11 @@ def test_loop_basic():
loop.get_iter_value(curr_cma.output(0), -1)
loop.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
assert loop.get_function() == graph_body
subgraph_func = loop.get_function()
assert type(subgraph_func) == type(graph_body)
assert subgraph_func._get_raw_address() == graph_body._get_raw_address()
assert compare_models(subgraph_func, graph_body)
assert loop.get_special_body_ports() == body_ports
assert loop.get_num_iterations() == 16

View File

@ -13,6 +13,7 @@ from openvino.runtime.op.util import (
MergedInputDescription,
ConcatOutputDescription,
)
from tests.utils.helpers import compare_models
def test_simple_tensor_iterator():
@ -118,7 +119,11 @@ def test_tensor_iterator_basic():
ti.get_iter_value(curr_cma.output(0), -1)
ti.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
assert ti.get_function() == graph_body
subgraph_func = ti.get_function()
assert type(subgraph_func) == type(graph_body)
assert compare_models(subgraph_func, graph_body)
assert subgraph_func._get_raw_address() == graph_body._get_raw_address()
assert ti.get_num_iterations() == 16
input_desc = ti.get_input_descriptions()