[nGraph] Py API get/set partial shape of parameter (#1560)

This commit is contained in:
Jan Iwaszkiewicz 2020-07-31 10:14:39 +02:00 committed by GitHub
parent c5bac5a1b9
commit 43652498c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 0 deletions

View File

@ -42,4 +42,12 @@ void regclass_pyngraph_op_Parameter(py::module m)
parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>()); parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>());
parameter.def(py::init<const ngraph::element::Type&, const ngraph::PartialShape&>()); parameter.def(py::init<const ngraph::element::Type&, const ngraph::PartialShape&>());
// parameter.def_property_readonly("description", &ngraph::op::Parameter::description); // parameter.def_property_readonly("description", &ngraph::op::Parameter::description);
parameter.def("get_partial_shape",
(const ngraph::PartialShape& (ngraph::op::Parameter::*)() const) &
ngraph::op::Parameter::get_partial_shape);
parameter.def("get_partial_shape",
(ngraph::PartialShape & (ngraph::op::Parameter::*)()) &
ngraph::op::Parameter::get_partial_shape);
parameter.def("set_partial_shape", &ngraph::op::Parameter::set_partial_shape);
} }

View File

@ -33,6 +33,8 @@ def test_ngraph_function_api():
model = (parameter_a + parameter_b) * parameter_c model = (parameter_a + parameter_b) * parameter_c
function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction") function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction")
function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5]))
ordered_ops = function.get_ordered_ops() ordered_ops = function.get_ordered_ops()
op_types = [op.get_type_name() for op in ordered_ops] op_types = [op.get_type_name() for op in ordered_ops]
assert op_types == ["Parameter", "Parameter", "Parameter", "Add", "Multiply", "Result"] assert op_types == ["Parameter", "Parameter", "Parameter", "Add", "Multiply", "Result"]
@ -41,6 +43,7 @@ def test_ngraph_function_api():
assert function.get_output_op(0).get_type_name() == "Result" assert function.get_output_op(0).get_type_name() == "Result"
assert function.get_output_element_type(0) == parameter_a.get_element_type() assert function.get_output_element_type(0) == parameter_a.get_element_type()
assert list(function.get_output_shape(0)) == [2, 2] assert list(function.get_output_shape(0)) == [2, 2]
assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5])
assert len(function.get_parameters()) == 3 assert len(function.get_parameters()) == 3
assert len(function.get_results()) == 1 assert len(function.get_results()) == 1
assert function.get_friendly_name() == "TestFunction" assert function.get_friendly_name() == "TestFunction"