[Python API] Add more api for parameter and result ops (#8691)

* [Python API] Add more api for parameter and result ops

* add tests

* add api for result

* add py::args

* apply comments to test
This commit is contained in:
Anastasia Kuporosova 2021-11-23 23:03:03 +03:00 committed by GitHub
parent 6addc0d535
commit 3cc0517492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 2 deletions

View File

@ -34,5 +34,23 @@ void regclass_graph_op_Parameter(py::module m) {
(const ov::PartialShape& (ov::op::v0::Parameter::*)() const) & ov::op::v0::Parameter::get_partial_shape);
parameter.def("get_partial_shape",
(ov::PartialShape & (ov::op::v0::Parameter::*)()) & ov::op::v0::Parameter::get_partial_shape);
parameter.def("set_partial_shape", &ov::op::v0::Parameter::set_partial_shape);
parameter.def("set_partial_shape", &ov::op::v0::Parameter::set_partial_shape, py::arg("partial_shape"));
parameter.def("get_element_type", &ov::op::v0::Parameter::get_element_type);
parameter.def("set_element_type", &ov::op::v0::Parameter::set_element_type, py::arg("element_type"));
parameter.def("get_layout", &ov::op::v0::Parameter::get_layout);
parameter.def("set_layout", &ov::op::v0::Parameter::set_layout, py::arg("layout"));
parameter.def_property("partial_shape",
(ov::PartialShape & (ov::op::v0::Parameter::*)()) & ov::op::v0::Parameter::get_partial_shape,
&ov::op::v0::Parameter::set_partial_shape);
parameter.def_property("element_type",
&ov::op::v0::Parameter::get_element_type,
&ov::op::v0::Parameter::set_element_type);
parameter.def_property("layout", &ov::op::v0::Parameter::get_layout, &ov::op::v0::Parameter::set_layout);
}

View File

@ -18,4 +18,14 @@ void regclass_graph_op_Result(py::module m) {
py::class_<ov::op::v0::Result, std::shared_ptr<ov::op::v0::Result>, ov::Node> result(m, "Result");
result.doc() = "openvino.impl.op.Result wraps ov::op::v0::Result";
result.def("get_output_partial_shape", &ov::Node::get_output_partial_shape, py::arg("index"));
result.def("get_output_element_type", &ov::Node::get_output_element_type, py::arg("index"));
result.def("get_layout", &ov::op::v0::Result::get_layout);
result.def("set_layout", &ov::op::v0::Result::set_layout, py::arg("layout"));
result.def_property("layout", &ov::op::v0::Result::get_layout, &ov::op::v0::Result::set_layout);
}

View File

@ -24,6 +24,11 @@ def test_ngraph_function_api():
parameter_b = ops.parameter(shape, dtype=np.float32, name="B")
parameter_c = ops.parameter(shape, dtype=np.float32, name="C")
model = (parameter_a + parameter_b) * parameter_c
assert parameter_a.element_type == Type.f32
assert parameter_a.partial_shape == PartialShape([2, 2])
parameter_a.layout = ov.Layout("NCWH")
assert parameter_a.layout == ov.Layout("NCWH")
function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction")
function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5]))
@ -44,7 +49,12 @@ def test_ngraph_function_api():
assert list(function.get_output_shape(0)) == [2, 2]
assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5])
assert len(function.get_parameters()) == 3
assert len(function.get_results()) == 1
results = function.get_results()
assert len(results) == 1
assert results[0].get_output_element_type(0) == Type.f32
assert results[0].get_output_partial_shape(0) == PartialShape([2, 2])
results[0].layout = ov.Layout("NC")
assert results[0].layout.to_string() == ov.Layout("NC")
assert function.get_friendly_name() == "TestFunction"