diff --git a/src/bindings/python/src/pyopenvino/graph/model.cpp b/src/bindings/python/src/pyopenvino/graph/model.cpp index 3bb4995d27a..f666fcd8e76 100644 --- a/src/bindings/python/src/pyopenvino/graph/model.cpp +++ b/src/bindings/python/src/pyopenvino/graph/model.cpp @@ -998,6 +998,18 @@ void regclass_graph_Model(py::module m) { :type path: str )"); + model.def( + "_get_raw_address", + [](ov::Model& self) { + return reinterpret_cast(&self); + }, + R"( + Returns raw address of the Model object. + + :return: a raw address of the Model object. + :rtype: int + )"); + model.def_property_readonly("inputs", (std::vector>(ov::Model::*)()) & ov::Model::inputs); model.def_property_readonly("outputs", (std::vector>(ov::Model::*)()) & ov::Model::outputs); model.def_property_readonly("name", &ov::Model::get_name); diff --git a/src/bindings/python/src/pyopenvino/graph/ops/if.cpp b/src/bindings/python/src/pyopenvino/graph/ops/if.cpp index d1eb84eb014..ad83891ebaa 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/if.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/if.cpp @@ -49,9 +49,14 @@ void regclass_graph_op_If(py::module m) { :rtype: openvino.impl.op.If )"); - cls.def("get_else_body", - &ov::op::v8::If::get_else_body, - R"( + cls.def( + "get_else_body", + [](ov::op::v8::If& self) { + auto model = self.get_else_body(); + py::type model_class = py::module_::import("openvino.runtime").attr("Model"); + return model_class(py::cast(model)); + }, + R"( Gets else_body as Model object. :return: else_body as Model object. @@ -119,10 +124,15 @@ void regclass_graph_op_If(py::module m) { :rtype: openvino.runtime.Output )"); - cls.def("get_function", - &ov::op::util::MultiSubGraphOp::get_function, - py::arg("index"), - R"( + cls.def( + "get_function", + [](ov::op::v8::If& self, size_t index) { + auto model = self.get_function(index); + py::type model_class = py::module_::import("openvino.runtime").attr("Model"); + return model_class(py::cast(model)); + }, + py::arg("index"), + R"( Gets internal sub-graph by index in MultiSubGraphOp. :param index: sub-graph's index in op. diff --git a/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp b/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp index c378c9cbe32..ebf317d3c46 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/loop.cpp @@ -92,7 +92,9 @@ void regclass_graph_op_Loop(py::module m) { py::arg("successive_value")); cls.def("get_function", [](const std::shared_ptr& self) { - return self->get_function(); + auto model = self->get_function(); + py::type model_class = py::module_::import("openvino.runtime").attr("Model"); + return model_class(py::cast(model)); }); cls.def( diff --git a/src/bindings/python/src/pyopenvino/graph/ops/tensor_iterator.cpp b/src/bindings/python/src/pyopenvino/graph/ops/tensor_iterator.cpp index 1a973fa04ed..1b2b159bb7e 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/tensor_iterator.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/tensor_iterator.cpp @@ -55,11 +55,15 @@ void regclass_graph_op_TensorIterator(py::module m) { py::arg("successive_value")); cls.def("get_body", [](const std::shared_ptr& self) { - return self->get_body(); + auto model = self->get_body(); + py::type model_class = py::module_::import("openvino.runtime").attr("Model"); + return model_class(py::cast(model)); }); cls.def("get_function", [](const std::shared_ptr& self) { - return self->get_function(); + auto model = self->get_function(); + py::type model_class = py::module_::import("openvino.runtime").attr("Model"); + return model_class(py::cast(model)); }); cls.def( diff --git a/src/bindings/python/tests/test_graph/test_if.py b/src/bindings/python/tests/test_graph/test_if.py index eff446219c1..7e165342a2c 100644 --- a/src/bindings/python/tests/test_graph/test_if.py +++ b/src/bindings/python/tests/test_graph/test_if.py @@ -8,6 +8,8 @@ from openvino.runtime import Model from openvino.runtime.op.util import InvariantInputDescription, BodyOutputDescription +from tests.utils.helpers import compare_models + def create_simple_if_with_two_outputs(condition_val): condition = ov.constant(condition_val, dtype=bool) @@ -191,7 +193,11 @@ def test_simple_if_basic(): if_node = ov.if_op(condition.output(0)) if_node.set_function(0, then_body) - assert if_node.get_function(0) == then_body + subgraph_func = if_node.get_function(0) + + assert type(subgraph_func) == type(then_body) + assert compare_models(subgraph_func, then_body) + assert subgraph_func._get_raw_address() == then_body._get_raw_address() if_node.set_input_descriptions(0, then_body_inputs) if_node.set_output_descriptions(1, else_body_outputs) diff --git a/src/bindings/python/tests/test_graph/test_loop.py b/src/bindings/python/tests/test_graph/test_loop.py index c30cfad1103..9a2fb6fcaf6 100644 --- a/src/bindings/python/tests/test_graph/test_loop.py +++ b/src/bindings/python/tests/test_graph/test_loop.py @@ -13,6 +13,7 @@ from openvino.runtime.op.util import ( MergedInputDescription, ConcatOutputDescription, ) +from tests.utils.helpers import compare_models def test_simple_loop(): @@ -139,7 +140,11 @@ def test_loop_basic(): loop.get_iter_value(curr_cma.output(0), -1) loop.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0) - assert loop.get_function() == graph_body + subgraph_func = loop.get_function() + + assert type(subgraph_func) == type(graph_body) + assert subgraph_func._get_raw_address() == graph_body._get_raw_address() + assert compare_models(subgraph_func, graph_body) assert loop.get_special_body_ports() == body_ports assert loop.get_num_iterations() == 16 diff --git a/src/bindings/python/tests/test_graph/test_tensor_iterator.py b/src/bindings/python/tests/test_graph/test_tensor_iterator.py index 36c50713f6a..dd58b3da3f4 100644 --- a/src/bindings/python/tests/test_graph/test_tensor_iterator.py +++ b/src/bindings/python/tests/test_graph/test_tensor_iterator.py @@ -13,6 +13,7 @@ from openvino.runtime.op.util import ( MergedInputDescription, ConcatOutputDescription, ) +from tests.utils.helpers import compare_models def test_simple_tensor_iterator(): @@ -118,7 +119,11 @@ def test_tensor_iterator_basic(): ti.get_iter_value(curr_cma.output(0), -1) ti.get_concatenated_slices(cma_hist.output(0), 0, 1, 1, -1, 0) - assert ti.get_function() == graph_body + subgraph_func = ti.get_function() + + assert type(subgraph_func) == type(graph_body) + assert compare_models(subgraph_func, graph_body) + assert subgraph_func._get_raw_address() == graph_body._get_raw_address() assert ti.get_num_iterations() == 16 input_desc = ti.get_input_descriptions()