[nGraph] Matching names of functions (#1524)

This commit is contained in:
Jan Iwaszkiewicz 2020-07-30 13:25:42 +02:00 committed by GitHub
parent 66ebc76512
commit 0f62031991
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 10 deletions

View File

@ -45,8 +45,8 @@ void regclass_pyngraph_Function(py::module m)
function.def("get_parameters", &ngraph::Function::get_parameters); function.def("get_parameters", &ngraph::Function::get_parameters);
function.def("get_results", &ngraph::Function::get_results); function.def("get_results", &ngraph::Function::get_results);
function.def("get_result", &ngraph::Function::get_result); function.def("get_result", &ngraph::Function::get_result);
function.def("get_unique_name", &ngraph::Function::get_name); function.def("get_name", &ngraph::Function::get_name);
function.def("get_name", &ngraph::Function::get_friendly_name); function.def("get_friendly_name", &ngraph::Function::get_friendly_name);
function.def("set_friendly_name", &ngraph::Function::set_friendly_name); function.def("set_friendly_name", &ngraph::Function::set_friendly_name);
function.def("is_dynamic", &ngraph::Function::is_dynamic); function.def("is_dynamic", &ngraph::Function::is_dynamic);
function.def("__repr__", [](const ngraph::Function& self) { function.def("__repr__", [](const ngraph::Function& self) {
@ -99,4 +99,9 @@ void regclass_pyngraph_Function(py::module m)
return pybind_capsule; return pybind_capsule;
}); });
function.def_property_readonly("name", &ngraph::Function::get_name);
function.def_property("friendly_name",
&ngraph::Function::get_friendly_name,
&ngraph::Function::set_friendly_name);
} }

View File

@ -76,7 +76,7 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_output_shape", &ngraph::Node::get_output_shape); node.def("get_output_shape", &ngraph::Node::get_output_shape);
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape); node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
node.def("get_type_name", &ngraph::Node::get_type_name); node.def("get_type_name", &ngraph::Node::get_type_name);
node.def("get_unique_name", &ngraph::Node::get_name); node.def("get_name", &ngraph::Node::get_name);
node.def("get_friendly_name", &ngraph::Node::get_friendly_name); node.def("get_friendly_name", &ngraph::Node::get_friendly_name);
node.def("set_friendly_name", &ngraph::Node::set_friendly_name); node.def("set_friendly_name", &ngraph::Node::set_friendly_name);
node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input); node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input);

View File

@ -31,6 +31,12 @@ void regclass_pyngraph_op_util_ArithmeticReduction(py::module m)
// arithmeticReduction.def(py::init<const std::string&, // arithmeticReduction.def(py::init<const std::string&,
// const std::shared_ptr<ngraph::Node>&, // const std::shared_ptr<ngraph::Node>&,
// const ngraph::AxisSet& >()); // const ngraph::AxisSet& >());
arithmeticReduction.def_property_readonly( arithmeticReduction.def("get_reduction_axes",
"reduction_axes", &ngraph::op::util::ArithmeticReduction::get_reduction_axes); &ngraph::op::util::ArithmeticReduction::get_reduction_axes);
arithmeticReduction.def("set_reduction_axes",
&ngraph::op::util::ArithmeticReduction::set_reduction_axes);
arithmeticReduction.def_property("reduction_axes",
&ngraph::op::util::ArithmeticReduction::get_reduction_axes,
&ngraph::op::util::ArithmeticReduction::set_reduction_axes);
} }

View File

@ -27,8 +27,18 @@ void regclass_pyngraph_op_util_IndexReduction(py::module m)
{ {
py::class_<ngraph::op::util::IndexReduction, std::shared_ptr<ngraph::op::util::IndexReduction>> py::class_<ngraph::op::util::IndexReduction, std::shared_ptr<ngraph::op::util::IndexReduction>>
indexReduction(m, "IndexRedection"); indexReduction(m, "IndexRedection");
indexReduction.def_property_readonly("reduction_axis",
&ngraph::op::util::IndexReduction::get_reduction_axis); indexReduction.def("get_reduction_axis", &ngraph::op::util::IndexReduction::get_reduction_axis);
indexReduction.def_property_readonly("index_element_type", indexReduction.def("set_reduction_axis", &ngraph::op::util::IndexReduction::set_reduction_axis);
&ngraph::op::util::IndexReduction::get_index_element_type); indexReduction.def("get_index_element_type",
&ngraph::op::util::IndexReduction::get_index_element_type);
indexReduction.def("set_index_element_type",
&ngraph::op::util::IndexReduction::set_index_element_type);
indexReduction.def_property("reduction_axis",
&ngraph::op::util::IndexReduction::get_reduction_axis,
&ngraph::op::util::IndexReduction::set_reduction_axis);
indexReduction.def_property("index_element_type",
&ngraph::op::util::IndexReduction::get_index_element_type,
&ngraph::op::util::IndexReduction::set_index_element_type);
} }

View File

@ -43,7 +43,7 @@ def test_ngraph_function_api():
assert list(function.get_output_shape(0)) == [2, 2] assert list(function.get_output_shape(0)) == [2, 2]
assert len(function.get_parameters()) == 3 assert len(function.get_parameters()) == 3
assert len(function.get_results()) == 1 assert len(function.get_results()) == 1
assert function.get_name() == "TestFunction" assert function.get_friendly_name() == "TestFunction"
@pytest.mark.parametrize( @pytest.mark.parametrize(