[PyOV] fix if_op return types (#20014)
* [PyOV] fix if_op return types * fix if * other types * test for if * add tests, try fix loop, ti * fix ci
This commit is contained in:
parent
ca7031f69d
commit
c90bcbf5eb
@ -998,6 +998,18 @@ void regclass_graph_Model(py::module m) {
|
|||||||
:type path: str
|
:type path: str
|
||||||
)");
|
)");
|
||||||
|
|
||||||
|
model.def(
|
||||||
|
"_get_raw_address",
|
||||||
|
[](ov::Model& self) {
|
||||||
|
return reinterpret_cast<uint64_t>(&self);
|
||||||
|
},
|
||||||
|
R"(
|
||||||
|
Returns raw address of the Model object.
|
||||||
|
|
||||||
|
:return: a raw address of the Model object.
|
||||||
|
:rtype: int
|
||||||
|
)");
|
||||||
|
|
||||||
model.def_property_readonly("inputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::inputs);
|
model.def_property_readonly("inputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::inputs);
|
||||||
model.def_property_readonly("outputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::outputs);
|
model.def_property_readonly("outputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::outputs);
|
||||||
model.def_property_readonly("name", &ov::Model::get_name);
|
model.def_property_readonly("name", &ov::Model::get_name);
|
||||||
|
@ -49,9 +49,14 @@ void regclass_graph_op_If(py::module m) {
|
|||||||
:rtype: openvino.impl.op.If
|
:rtype: openvino.impl.op.If
|
||||||
)");
|
)");
|
||||||
|
|
||||||
cls.def("get_else_body",
|
cls.def(
|
||||||
&ov::op::v8::If::get_else_body,
|
"get_else_body",
|
||||||
R"(
|
[](ov::op::v8::If& self) {
|
||||||
|
auto model = self.get_else_body();
|
||||||
|
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||||
|
return model_class(py::cast(model));
|
||||||
|
},
|
||||||
|
R"(
|
||||||
Gets else_body as Model object.
|
Gets else_body as Model object.
|
||||||
|
|
||||||
:return: else_body as Model object.
|
:return: else_body as Model object.
|
||||||
@ -119,10 +124,15 @@ void regclass_graph_op_If(py::module m) {
|
|||||||
:rtype: openvino.runtime.Output
|
:rtype: openvino.runtime.Output
|
||||||
)");
|
)");
|
||||||
|
|
||||||
cls.def("get_function",
|
cls.def(
|
||||||
&ov::op::util::MultiSubGraphOp::get_function,
|
"get_function",
|
||||||
py::arg("index"),
|
[](ov::op::v8::If& self, size_t index) {
|
||||||
R"(
|
auto model = self.get_function(index);
|
||||||
|
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||||
|
return model_class(py::cast(model));
|
||||||
|
},
|
||||||
|
py::arg("index"),
|
||||||
|
R"(
|
||||||
Gets internal sub-graph by index in MultiSubGraphOp.
|
Gets internal sub-graph by index in MultiSubGraphOp.
|
||||||
|
|
||||||
:param index: sub-graph's index in op.
|
:param index: sub-graph's index in op.
|
||||||
|
@ -92,7 +92,9 @@ void regclass_graph_op_Loop(py::module m) {
|
|||||||
py::arg("successive_value"));
|
py::arg("successive_value"));
|
||||||
|
|
||||||
cls.def("get_function", [](const std::shared_ptr<ov::op::v5::Loop>& self) {
|
cls.def("get_function", [](const std::shared_ptr<ov::op::v5::Loop>& self) {
|
||||||
return self->get_function();
|
auto model = self->get_function();
|
||||||
|
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||||
|
return model_class(py::cast(model));
|
||||||
});
|
});
|
||||||
|
|
||||||
cls.def(
|
cls.def(
|
||||||
|
@ -55,11 +55,15 @@ void regclass_graph_op_TensorIterator(py::module m) {
|
|||||||
py::arg("successive_value"));
|
py::arg("successive_value"));
|
||||||
|
|
||||||
cls.def("get_body", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
|
cls.def("get_body", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
|
||||||
return self->get_body();
|
auto model = self->get_body();
|
||||||
|
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||||
|
return model_class(py::cast(model));
|
||||||
});
|
});
|
||||||
|
|
||||||
cls.def("get_function", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
|
cls.def("get_function", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
|
||||||
return self->get_function();
|
auto model = self->get_function();
|
||||||
|
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||||
|
return model_class(py::cast(model));
|
||||||
});
|
});
|
||||||
|
|
||||||
cls.def(
|
cls.def(
|
||||||
|
@ -8,6 +8,8 @@ from openvino.runtime import Model
|
|||||||
|
|
||||||
from openvino.runtime.op.util import InvariantInputDescription, BodyOutputDescription
|
from openvino.runtime.op.util import InvariantInputDescription, BodyOutputDescription
|
||||||
|
|
||||||
|
from tests.utils.helpers import compare_models
|
||||||
|
|
||||||
|
|
||||||
def create_simple_if_with_two_outputs(condition_val):
|
def create_simple_if_with_two_outputs(condition_val):
|
||||||
condition = ov.constant(condition_val, dtype=bool)
|
condition = ov.constant(condition_val, dtype=bool)
|
||||||
@ -191,7 +193,11 @@ def test_simple_if_basic():
|
|||||||
|
|
||||||
if_node = ov.if_op(condition.output(0))
|
if_node = ov.if_op(condition.output(0))
|
||||||
if_node.set_function(0, then_body)
|
if_node.set_function(0, then_body)
|
||||||
assert if_node.get_function(0) == then_body
|
subgraph_func = if_node.get_function(0)
|
||||||
|
|
||||||
|
assert type(subgraph_func) == type(then_body)
|
||||||
|
assert compare_models(subgraph_func, then_body)
|
||||||
|
assert subgraph_func._get_raw_address() == then_body._get_raw_address()
|
||||||
|
|
||||||
if_node.set_input_descriptions(0, then_body_inputs)
|
if_node.set_input_descriptions(0, then_body_inputs)
|
||||||
if_node.set_output_descriptions(1, else_body_outputs)
|
if_node.set_output_descriptions(1, else_body_outputs)
|
||||||
|
@ -13,6 +13,7 @@ from openvino.runtime.op.util import (
|
|||||||
MergedInputDescription,
|
MergedInputDescription,
|
||||||
ConcatOutputDescription,
|
ConcatOutputDescription,
|
||||||
)
|
)
|
||||||
|
from tests.utils.helpers import compare_models
|
||||||
|
|
||||||
|
|
||||||
def test_simple_loop():
|
def test_simple_loop():
|
||||||
@ -139,7 +140,11 @@ def test_loop_basic():
|
|||||||
loop.get_iter_value(curr_cma.output(0), -1)
|
loop.get_iter_value(curr_cma.output(0), -1)
|
||||||
loop.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
|
loop.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
|
||||||
|
|
||||||
assert loop.get_function() == graph_body
|
subgraph_func = loop.get_function()
|
||||||
|
|
||||||
|
assert type(subgraph_func) == type(graph_body)
|
||||||
|
assert subgraph_func._get_raw_address() == graph_body._get_raw_address()
|
||||||
|
assert compare_models(subgraph_func, graph_body)
|
||||||
assert loop.get_special_body_ports() == body_ports
|
assert loop.get_special_body_ports() == body_ports
|
||||||
assert loop.get_num_iterations() == 16
|
assert loop.get_num_iterations() == 16
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from openvino.runtime.op.util import (
|
|||||||
MergedInputDescription,
|
MergedInputDescription,
|
||||||
ConcatOutputDescription,
|
ConcatOutputDescription,
|
||||||
)
|
)
|
||||||
|
from tests.utils.helpers import compare_models
|
||||||
|
|
||||||
|
|
||||||
def test_simple_tensor_iterator():
|
def test_simple_tensor_iterator():
|
||||||
@ -118,7 +119,11 @@ def test_tensor_iterator_basic():
|
|||||||
ti.get_iter_value(curr_cma.output(0), -1)
|
ti.get_iter_value(curr_cma.output(0), -1)
|
||||||
ti.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
|
ti.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
|
||||||
|
|
||||||
assert ti.get_function() == graph_body
|
subgraph_func = ti.get_function()
|
||||||
|
|
||||||
|
assert type(subgraph_func) == type(graph_body)
|
||||||
|
assert compare_models(subgraph_func, graph_body)
|
||||||
|
assert subgraph_func._get_raw_address() == graph_body._get_raw_address()
|
||||||
assert ti.get_num_iterations() == 16
|
assert ti.get_num_iterations() == 16
|
||||||
|
|
||||||
input_desc = ti.get_input_descriptions()
|
input_desc = ti.get_input_descriptions()
|
||||||
|
Loading…
Reference in New Issue
Block a user