diff --git a/ngraph/python/src/pyngraph/ops/parameter.cpp b/ngraph/python/src/pyngraph/ops/parameter.cpp index 1290157376b..90dcc3b1625 100644 --- a/ngraph/python/src/pyngraph/ops/parameter.cpp +++ b/ngraph/python/src/pyngraph/ops/parameter.cpp @@ -42,4 +42,12 @@ void regclass_pyngraph_op_Parameter(py::module m) parameter.def(py::init()); parameter.def(py::init()); // parameter.def_property_readonly("description", &ngraph::op::Parameter::description); + + parameter.def("get_partial_shape", + (const ngraph::PartialShape& (ngraph::op::Parameter::*)() const) & + ngraph::op::Parameter::get_partial_shape); + parameter.def("get_partial_shape", + (ngraph::PartialShape & (ngraph::op::Parameter::*)()) & + ngraph::op::Parameter::get_partial_shape); + parameter.def("set_partial_shape", &ngraph::op::Parameter::set_partial_shape); } diff --git a/ngraph/python/tests/test_ngraph/test_basic.py b/ngraph/python/tests/test_ngraph/test_basic.py index c25635f30b8..5821dc3e1ef 100644 --- a/ngraph/python/tests/test_ngraph/test_basic.py +++ b/ngraph/python/tests/test_ngraph/test_basic.py @@ -33,6 +33,8 @@ def test_ngraph_function_api(): model = (parameter_a + parameter_b) * parameter_c function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction") + function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5])) + ordered_ops = function.get_ordered_ops() op_types = [op.get_type_name() for op in ordered_ops] assert op_types == ["Parameter", "Parameter", "Parameter", "Add", "Multiply", "Result"] @@ -41,6 +43,7 @@ def test_ngraph_function_api(): assert function.get_output_op(0).get_type_name() == "Result" assert function.get_output_element_type(0) == parameter_a.get_element_type() assert list(function.get_output_shape(0)) == [2, 2] + assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5]) assert len(function.get_parameters()) == 3 assert len(function.get_results()) == 1 assert function.get_friendly_name() == "TestFunction"