[nGraph] Py API get/set partial shape of parameter (#1560)
This commit is contained in:
parent
c5bac5a1b9
commit
43652498c7
@ -42,4 +42,12 @@ void regclass_pyngraph_op_Parameter(py::module m)
|
||||
parameter.def(py::init<const ngraph::element::Type&, const ngraph::Shape&>());
|
||||
parameter.def(py::init<const ngraph::element::Type&, const ngraph::PartialShape&>());
|
||||
// parameter.def_property_readonly("description", &ngraph::op::Parameter::description);
|
||||
|
||||
parameter.def("get_partial_shape",
|
||||
(const ngraph::PartialShape& (ngraph::op::Parameter::*)() const) &
|
||||
ngraph::op::Parameter::get_partial_shape);
|
||||
parameter.def("get_partial_shape",
|
||||
(ngraph::PartialShape & (ngraph::op::Parameter::*)()) &
|
||||
ngraph::op::Parameter::get_partial_shape);
|
||||
parameter.def("set_partial_shape", &ngraph::op::Parameter::set_partial_shape);
|
||||
}
|
||||
|
@ -33,6 +33,8 @@ def test_ngraph_function_api():
|
||||
model = (parameter_a + parameter_b) * parameter_c
|
||||
function = Function(model, [parameter_a, parameter_b, parameter_c], "TestFunction")
|
||||
|
||||
function.get_parameters()[1].set_partial_shape(PartialShape([3, 4, 5]))
|
||||
|
||||
ordered_ops = function.get_ordered_ops()
|
||||
op_types = [op.get_type_name() for op in ordered_ops]
|
||||
assert op_types == ["Parameter", "Parameter", "Parameter", "Add", "Multiply", "Result"]
|
||||
@ -41,6 +43,7 @@ def test_ngraph_function_api():
|
||||
assert function.get_output_op(0).get_type_name() == "Result"
|
||||
assert function.get_output_element_type(0) == parameter_a.get_element_type()
|
||||
assert list(function.get_output_shape(0)) == [2, 2]
|
||||
assert (function.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5])
|
||||
assert len(function.get_parameters()) == 3
|
||||
assert len(function.get_results()) == 1
|
||||
assert function.get_friendly_name() == "TestFunction"
|
||||
|
Loading…
Reference in New Issue
Block a user