From 3cc0517492d9df20c4c8df0568ce4ec53c56db72 Mon Sep 17 00:00:00 2001 From: Anastasia Kuporosova Date: Tue, 23 Nov 2021 23:03:03 +0300 Subject: [PATCH] [Python API] Add more api for parameter and result ops (#8691) * [Python API] Add more api for parameter and result ops * add tests * add api for result * add py::args * apply comments to test --- .../src/pyopenvino/graph/ops/parameter.cpp | 20 ++++++++++++++++++- .../src/pyopenvino/graph/ops/result.cpp | 10 ++++++++++ .../python/tests/test_ngraph/test_basic.py | 12 ++++++++++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/runtime/bindings/python/src/pyopenvino/graph/ops/parameter.cpp b/runtime/bindings/python/src/pyopenvino/graph/ops/parameter.cpp index 3eb8997f432..33f779381ff 100644 --- a/runtime/bindings/python/src/pyopenvino/graph/ops/parameter.cpp +++ b/runtime/bindings/python/src/pyopenvino/graph/ops/parameter.cpp @@ -34,5 +34,23 @@ void regclass_graph_op_Parameter(py::module m) { (const ov::PartialShape& (ov::op::v0::Parameter::*)() const) & ov::op::v0::Parameter::get_partial_shape); parameter.def("get_partial_shape", (ov::PartialShape & (ov::op::v0::Parameter::*)()) & ov::op::v0::Parameter::get_partial_shape); - parameter.def("set_partial_shape", &ov::op::v0::Parameter::set_partial_shape); + parameter.def("set_partial_shape", &ov::op::v0::Parameter::set_partial_shape, py::arg("partial_shape")); + + parameter.def("get_element_type", &ov::op::v0::Parameter::get_element_type); + + parameter.def("set_element_type", &ov::op::v0::Parameter::set_element_type, py::arg("element_type")); + + parameter.def("get_layout", &ov::op::v0::Parameter::get_layout); + + parameter.def("set_layout", &ov::op::v0::Parameter::set_layout, py::arg("layout")); + + parameter.def_property("partial_shape", + (ov::PartialShape & (ov::op::v0::Parameter::*)()) & ov::op::v0::Parameter::get_partial_shape, + &ov::op::v0::Parameter::set_partial_shape); + + parameter.def_property("element_type", + &ov::op::v0::Parameter::get_element_type, + &ov::op::v0::Parameter::set_element_type); + + parameter.def_property("layout", &ov::op::v0::Parameter::get_layout, &ov::op::v0::Parameter::set_layout); } diff --git a/runtime/bindings/python/src/pyopenvino/graph/ops/result.cpp b/runtime/bindings/python/src/pyopenvino/graph/ops/result.cpp index fb917d71d79..5515cbee968 100644 --- a/runtime/bindings/python/src/pyopenvino/graph/ops/result.cpp +++ b/runtime/bindings/python/src/pyopenvino/graph/ops/result.cpp @@ -18,4 +18,14 @@ void regclass_graph_op_Result(py::module m) { py::class_, ov::Node> result(m, "Result"); result.doc() = "openvino.impl.op.Result wraps ov::op::v0::Result"; + + result.def("get_output_partial_shape", &ov::Node::get_output_partial_shape, py::arg("index")); + + result.def("get_output_element_type", &ov::Node::get_output_element_type, py::arg("index")); + + result.def("get_layout", &ov::op::v0::Result::get_layout); + + result.def("set_layout", &ov::op::v0::Result::set_layout, py::arg("layout")); + + result.def_property("layout", &ov::op::v0::Result::get_layout, &ov::op::v0::Result::set_layout); } diff --git a/runtime/bindings/python/tests/test_ngraph/test_basic.py b/runtime/bindings/python/tests/test_ngraph/test_basic.py index 011abee486e..47b1859de9f 100644 --- a/runtime/bindings/python/tests/test_ngraph/test_basic.py +++ b/runtime/bindings/python/tests/test_ngraph/test_basic.py @@ -24,6 +24,11 @@ def test_ngraph_function_api(): parameter_b = ops.parameter(shape, dtype=np.float32, name="B") parameter_c = ops.parameter(shape, dtype=np.float32, name="C") model = (parameter_a + parameter_b) * parameter_c + + assert parameter_a.element_type == Type.f32 + assert parameter_a.partial_shape == PartialShape([2, 2]) + parameter_a.layout = ov.Layout("NCWH") + assert parameter_a.layout == ov.Layout("NCWH") function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction") function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5])) @@ -44,7 +49,12 @@ def test_ngraph_function_api(): assert list(function.get_output_shape(0)) == [2, 2] assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5]) assert len(function.get_parameters()) == 3 - assert len(function.get_results()) == 1 + results = function.get_results() + assert len(results) == 1 + assert results[0].get_output_element_type(0) == Type.f32 + assert results[0].get_output_partial_shape(0) == PartialShape([2, 2]) + results[0].layout = ov.Layout("NC") + assert results[0].layout.to_string() == ov.Layout("NC") assert function.get_friendly_name() == "TestFunction"