[PyOV] fix if_op return types (#20014)
* [PyOV] fix if_op return types * fix if * other types * test for if * add tests, try fix loop, ti * fix ci
This commit is contained in:
parent
ca7031f69d
commit
c90bcbf5eb
@ -998,6 +998,18 @@ void regclass_graph_Model(py::module m) {
|
||||
:type path: str
|
||||
)");
|
||||
|
||||
model.def(
|
||||
"_get_raw_address",
|
||||
[](ov::Model& self) {
|
||||
return reinterpret_cast<uint64_t>(&self);
|
||||
},
|
||||
R"(
|
||||
Returns raw address of the Model object.
|
||||
|
||||
:return: a raw address of the Model object.
|
||||
:rtype: int
|
||||
)");
|
||||
|
||||
model.def_property_readonly("inputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::inputs);
|
||||
model.def_property_readonly("outputs", (std::vector<ov::Output<ov::Node>>(ov::Model::*)()) & ov::Model::outputs);
|
||||
model.def_property_readonly("name", &ov::Model::get_name);
|
||||
|
@ -49,8 +49,13 @@ void regclass_graph_op_If(py::module m) {
|
||||
:rtype: openvino.impl.op.If
|
||||
)");
|
||||
|
||||
cls.def("get_else_body",
|
||||
&ov::op::v8::If::get_else_body,
|
||||
cls.def(
|
||||
"get_else_body",
|
||||
[](ov::op::v8::If& self) {
|
||||
auto model = self.get_else_body();
|
||||
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||
return model_class(py::cast(model));
|
||||
},
|
||||
R"(
|
||||
Gets else_body as Model object.
|
||||
|
||||
@ -119,8 +124,13 @@ void regclass_graph_op_If(py::module m) {
|
||||
:rtype: openvino.runtime.Output
|
||||
)");
|
||||
|
||||
cls.def("get_function",
|
||||
&ov::op::util::MultiSubGraphOp::get_function,
|
||||
cls.def(
|
||||
"get_function",
|
||||
[](ov::op::v8::If& self, size_t index) {
|
||||
auto model = self.get_function(index);
|
||||
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||
return model_class(py::cast(model));
|
||||
},
|
||||
py::arg("index"),
|
||||
R"(
|
||||
Gets internal sub-graph by index in MultiSubGraphOp.
|
||||
|
@ -92,7 +92,9 @@ void regclass_graph_op_Loop(py::module m) {
|
||||
py::arg("successive_value"));
|
||||
|
||||
cls.def("get_function", [](const std::shared_ptr<ov::op::v5::Loop>& self) {
|
||||
return self->get_function();
|
||||
auto model = self->get_function();
|
||||
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||
return model_class(py::cast(model));
|
||||
});
|
||||
|
||||
cls.def(
|
||||
|
@ -55,11 +55,15 @@ void regclass_graph_op_TensorIterator(py::module m) {
|
||||
py::arg("successive_value"));
|
||||
|
||||
cls.def("get_body", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
|
||||
return self->get_body();
|
||||
auto model = self->get_body();
|
||||
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||
return model_class(py::cast(model));
|
||||
});
|
||||
|
||||
cls.def("get_function", [](const std::shared_ptr<ov::op::v0::TensorIterator>& self) {
|
||||
return self->get_function();
|
||||
auto model = self->get_function();
|
||||
py::type model_class = py::module_::import("openvino.runtime").attr("Model");
|
||||
return model_class(py::cast(model));
|
||||
});
|
||||
|
||||
cls.def(
|
||||
|
@ -8,6 +8,8 @@ from openvino.runtime import Model
|
||||
|
||||
from openvino.runtime.op.util import InvariantInputDescription, BodyOutputDescription
|
||||
|
||||
from tests.utils.helpers import compare_models
|
||||
|
||||
|
||||
def create_simple_if_with_two_outputs(condition_val):
|
||||
condition = ov.constant(condition_val, dtype=bool)
|
||||
@ -191,7 +193,11 @@ def test_simple_if_basic():
|
||||
|
||||
if_node = ov.if_op(condition.output(0))
|
||||
if_node.set_function(0, then_body)
|
||||
assert if_node.get_function(0) == then_body
|
||||
subgraph_func = if_node.get_function(0)
|
||||
|
||||
assert type(subgraph_func) == type(then_body)
|
||||
assert compare_models(subgraph_func, then_body)
|
||||
assert subgraph_func._get_raw_address() == then_body._get_raw_address()
|
||||
|
||||
if_node.set_input_descriptions(0, then_body_inputs)
|
||||
if_node.set_output_descriptions(1, else_body_outputs)
|
||||
|
@ -13,6 +13,7 @@ from openvino.runtime.op.util import (
|
||||
MergedInputDescription,
|
||||
ConcatOutputDescription,
|
||||
)
|
||||
from tests.utils.helpers import compare_models
|
||||
|
||||
|
||||
def test_simple_loop():
|
||||
@ -139,7 +140,11 @@ def test_loop_basic():
|
||||
loop.get_iter_value(curr_cma.output(0), -1)
|
||||
loop.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
|
||||
|
||||
assert loop.get_function() == graph_body
|
||||
subgraph_func = loop.get_function()
|
||||
|
||||
assert type(subgraph_func) == type(graph_body)
|
||||
assert subgraph_func._get_raw_address() == graph_body._get_raw_address()
|
||||
assert compare_models(subgraph_func, graph_body)
|
||||
assert loop.get_special_body_ports() == body_ports
|
||||
assert loop.get_num_iterations() == 16
|
||||
|
||||
|
@ -13,6 +13,7 @@ from openvino.runtime.op.util import (
|
||||
MergedInputDescription,
|
||||
ConcatOutputDescription,
|
||||
)
|
||||
from tests.utils.helpers import compare_models
|
||||
|
||||
|
||||
def test_simple_tensor_iterator():
|
||||
@ -118,7 +119,11 @@ def test_tensor_iterator_basic():
|
||||
ti.get_iter_value(curr_cma.output(0), -1)
|
||||
ti.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0)
|
||||
|
||||
assert ti.get_function() == graph_body
|
||||
subgraph_func = ti.get_function()
|
||||
|
||||
assert type(subgraph_func) == type(graph_body)
|
||||
assert compare_models(subgraph_func, graph_body)
|
||||
assert subgraph_func._get_raw_address() == graph_body._get_raw_address()
|
||||
assert ti.get_num_iterations() == 16
|
||||
|
||||
input_desc = ti.get_input_descriptions()
|
||||
|
Loading…
Reference in New Issue
Block a user