[Python API] Add more api for parameter and result ops (#8691)
* [Python API] Add more api for parameter and result ops * add tests * add api for result * add py::args * apply comments to test
This commit is contained in:
parent
6addc0d535
commit
3cc0517492
@ -34,5 +34,23 @@ void regclass_graph_op_Parameter(py::module m) {
|
||||
(const ov::PartialShape& (ov::op::v0::Parameter::*)() const) & ov::op::v0::Parameter::get_partial_shape);
|
||||
parameter.def("get_partial_shape",
|
||||
(ov::PartialShape & (ov::op::v0::Parameter::*)()) & ov::op::v0::Parameter::get_partial_shape);
|
||||
parameter.def("set_partial_shape", &ov::op::v0::Parameter::set_partial_shape);
|
||||
parameter.def("set_partial_shape", &ov::op::v0::Parameter::set_partial_shape, py::arg("partial_shape"));
|
||||
|
||||
parameter.def("get_element_type", &ov::op::v0::Parameter::get_element_type);
|
||||
|
||||
parameter.def("set_element_type", &ov::op::v0::Parameter::set_element_type, py::arg("element_type"));
|
||||
|
||||
parameter.def("get_layout", &ov::op::v0::Parameter::get_layout);
|
||||
|
||||
parameter.def("set_layout", &ov::op::v0::Parameter::set_layout, py::arg("layout"));
|
||||
|
||||
parameter.def_property("partial_shape",
|
||||
(ov::PartialShape & (ov::op::v0::Parameter::*)()) & ov::op::v0::Parameter::get_partial_shape,
|
||||
&ov::op::v0::Parameter::set_partial_shape);
|
||||
|
||||
parameter.def_property("element_type",
|
||||
&ov::op::v0::Parameter::get_element_type,
|
||||
&ov::op::v0::Parameter::set_element_type);
|
||||
|
||||
parameter.def_property("layout", &ov::op::v0::Parameter::get_layout, &ov::op::v0::Parameter::set_layout);
|
||||
}
|
||||
|
@ -18,4 +18,14 @@ void regclass_graph_op_Result(py::module m) {
|
||||
py::class_<ov::op::v0::Result, std::shared_ptr<ov::op::v0::Result>, ov::Node> result(m, "Result");
|
||||
|
||||
result.doc() = "openvino.impl.op.Result wraps ov::op::v0::Result";
|
||||
|
||||
result.def("get_output_partial_shape", &ov::Node::get_output_partial_shape, py::arg("index"));
|
||||
|
||||
result.def("get_output_element_type", &ov::Node::get_output_element_type, py::arg("index"));
|
||||
|
||||
result.def("get_layout", &ov::op::v0::Result::get_layout);
|
||||
|
||||
result.def("set_layout", &ov::op::v0::Result::set_layout, py::arg("layout"));
|
||||
|
||||
result.def_property("layout", &ov::op::v0::Result::get_layout, &ov::op::v0::Result::set_layout);
|
||||
}
|
||||
|
@ -24,6 +24,11 @@ def test_ngraph_function_api():
|
||||
parameter_b = ops.parameter(shape, dtype=np.float32, name="B")
|
||||
parameter_c = ops.parameter(shape, dtype=np.float32, name="C")
|
||||
model = (parameter_a + parameter_b) * parameter_c
|
||||
|
||||
assert parameter_a.element_type == Type.f32
|
||||
assert parameter_a.partial_shape == PartialShape([2, 2])
|
||||
parameter_a.layout = ov.Layout("NCWH")
|
||||
assert parameter_a.layout == ov.Layout("NCWH")
|
||||
function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction")
|
||||
|
||||
function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5]))
|
||||
@ -44,7 +49,12 @@ def test_ngraph_function_api():
|
||||
assert list(function.get_output_shape(0)) == [2, 2]
|
||||
assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5])
|
||||
assert len(function.get_parameters()) == 3
|
||||
assert len(function.get_results()) == 1
|
||||
results = function.get_results()
|
||||
assert len(results) == 1
|
||||
assert results[0].get_output_element_type(0) == Type.f32
|
||||
assert results[0].get_output_partial_shape(0) == PartialShape([2, 2])
|
||||
results[0].layout = ov.Layout("NC")
|
||||
assert results[0].layout.to_string() == ov.Layout("NC")
|
||||
assert function.get_friendly_name() == "TestFunction"
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user