From db09547087effbe72abed97d50eace523c80823d Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Wed, 15 Jul 2020 15:32:24 +0200 Subject: [PATCH] Add Input and Output class to Py API (#1284) --- ngraph/python/setup.py | 2 + ngraph/python/src/ngraph/impl/__init__.py | 2 + ngraph/python/src/pyngraph/node.cpp | 8 ++ ngraph/python/src/pyngraph/node_input.cpp | 37 ++++++++ ngraph/python/src/pyngraph/node_input.hpp | 23 +++++ ngraph/python/src/pyngraph/node_output.cpp | 37 ++++++++ ngraph/python/src/pyngraph/node_output.hpp | 23 +++++ ngraph/python/src/pyngraph/pyngraph.cpp | 4 + ngraph/python/tests/test_ngraph/test_basic.py | 88 ++++++++++++++++++- 9 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 ngraph/python/src/pyngraph/node_input.cpp create mode 100644 ngraph/python/src/pyngraph/node_input.hpp create mode 100644 ngraph/python/src/pyngraph/node_output.cpp create mode 100644 ngraph/python/src/pyngraph/node_output.hpp diff --git a/ngraph/python/setup.py b/ngraph/python/setup.py index 7b32b945ddb..132486c695d 100644 --- a/ngraph/python/setup.py +++ b/ngraph/python/setup.py @@ -185,6 +185,8 @@ sources = [ "pyngraph/dimension.cpp", "pyngraph/function.cpp", "pyngraph/node.cpp", + "pyngraph/node_input.cpp", + "pyngraph/node_output.cpp", "pyngraph/node_factory.cpp", "pyngraph/ops/constant.cpp", "pyngraph/ops/get_output_element.cpp", diff --git a/ngraph/python/src/ngraph/impl/__init__.py b/ngraph/python/src/ngraph/impl/__init__.py index fd9f574b913..314fcf0c970 100644 --- a/ngraph/python/src/ngraph/impl/__init__.py +++ b/ngraph/python/src/ngraph/impl/__init__.py @@ -34,6 +34,8 @@ if sys.platform == "win32": from _pyngraph import Dimension from _pyngraph import Function +from _pyngraph import Input +from _pyngraph import Output from _pyngraph import Node from _pyngraph import Type from _pyngraph import PartialShape diff --git a/ngraph/python/src/pyngraph/node.cpp b/ngraph/python/src/pyngraph/node.cpp index ba7decb4baf..a6b70584059 100644 --- a/ngraph/python/src/pyngraph/node.cpp +++ b/ngraph/python/src/pyngraph/node.cpp @@ -77,6 +77,14 @@ void regclass_pyngraph_Node(py::module m) node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape); node.def("get_type_name", &ngraph::Node::get_type_name); node.def("get_unique_name", &ngraph::Node::get_name); + node.def("input", (ngraph::Input(ngraph::Node::*)(size_t)) & ngraph::Node::input); + node.def("inputs", + (std::vector>(ngraph::Node::*)()) & ngraph::Node::inputs); + node.def("output", + (ngraph::Output(ngraph::Node::*)(size_t)) & ngraph::Node::output); + node.def("outputs", + (std::vector>(ngraph::Node::*)()) & + ngraph::Node::outputs); node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name); node.def_property_readonly("shape", &ngraph::Node::get_shape); diff --git a/ngraph/python/src/pyngraph/node_input.cpp b/ngraph/python/src/pyngraph/node_input.cpp new file mode 100644 index 00000000000..344f4da6748 --- /dev/null +++ b/ngraph/python/src/pyngraph/node_input.cpp @@ -0,0 +1,37 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include "dict_attribute_visitor.hpp" +#include "ngraph/node_input.hpp" +#include "pyngraph/node_input.hpp" + +namespace py = pybind11; + +void regclass_pyngraph_Input(py::module m) +{ + py::class_, std::shared_ptr>> input( + m, "Input", py::dynamic_attr()); + input.doc() = "ngraph.impl.Input wraps ngraph::Input"; + + input.def("get_node", &ngraph::Input::get_node); + input.def("get_index", &ngraph::Input::get_index); + input.def("get_element_type", &ngraph::Input::get_element_type); + input.def("get_shape", &ngraph::Input::get_shape); + input.def("get_partial_shape", &ngraph::Input::get_partial_shape); + input.def("get_source_output", &ngraph::Input::get_source_output); +} diff --git a/ngraph/python/src/pyngraph/node_input.hpp b/ngraph/python/src/pyngraph/node_input.hpp new file mode 100644 index 00000000000..77571b4397b --- /dev/null +++ b/ngraph/python/src/pyngraph/node_input.hpp @@ -0,0 +1,23 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_pyngraph_Input(py::module m); diff --git a/ngraph/python/src/pyngraph/node_output.cpp b/ngraph/python/src/pyngraph/node_output.cpp new file mode 100644 index 00000000000..b2544496dcc --- /dev/null +++ b/ngraph/python/src/pyngraph/node_output.cpp @@ -0,0 +1,37 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include "dict_attribute_visitor.hpp" +#include "ngraph/node_output.hpp" +#include "pyngraph/node_output.hpp" + +namespace py = pybind11; + +void regclass_pyngraph_Output(py::module m) +{ + py::class_, std::shared_ptr>> output( + m, "Output", py::dynamic_attr()); + output.doc() = "ngraph.impl.Output wraps ngraph::Output"; + + output.def("get_node", &ngraph::Output::get_node); + output.def("get_index", &ngraph::Output::get_index); + output.def("get_element_type", &ngraph::Output::get_element_type); + output.def("get_shape", &ngraph::Output::get_shape); + output.def("get_partial_shape", &ngraph::Output::get_partial_shape); + output.def("get_target_inputs", &ngraph::Output::get_target_inputs); +} diff --git a/ngraph/python/src/pyngraph/node_output.hpp b/ngraph/python/src/pyngraph/node_output.hpp new file mode 100644 index 00000000000..334a8b457f7 --- /dev/null +++ b/ngraph/python/src/pyngraph/node_output.hpp @@ -0,0 +1,23 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_pyngraph_Output(py::module m); diff --git a/ngraph/python/src/pyngraph/pyngraph.cpp b/ngraph/python/src/pyngraph/pyngraph.cpp index ad0316ccd0e..1d0202cb41b 100644 --- a/ngraph/python/src/pyngraph/pyngraph.cpp +++ b/ngraph/python/src/pyngraph/pyngraph.cpp @@ -22,6 +22,8 @@ #include "pyngraph/function.hpp" #include "pyngraph/node.hpp" #include "pyngraph/node_factory.hpp" +#include "pyngraph/node_input.hpp" +#include "pyngraph/node_output.hpp" #if defined(NGRAPH_ONNX_IMPORT_ENABLE) #include "pyngraph/onnx_import/onnx_import.hpp" #endif @@ -43,6 +45,8 @@ PYBIND11_MODULE(_pyngraph, m) { m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph"; regclass_pyngraph_Node(m); + regclass_pyngraph_Input(m); + regclass_pyngraph_Output(m); regclass_pyngraph_NodeFactory(m); regclass_pyngraph_Dimension(m); // Dimension must be registered before PartialShape regclass_pyngraph_PartialShape(m); diff --git a/ngraph/python/tests/test_ngraph/test_basic.py b/ngraph/python/tests/test_ngraph/test_basic.py index c1caa3e59fb..661234c2ed6 100644 --- a/ngraph/python/tests/test_ngraph/test_basic.py +++ b/ngraph/python/tests/test_ngraph/test_basic.py @@ -20,7 +20,7 @@ import pytest import ngraph as ng from ngraph.exceptions import UserInputError -from ngraph.impl import Function +from ngraph.impl import Function, PartialShape, Shape from tests.runtime import get_runtime from tests.test_ngraph.util import run_op_node @@ -271,3 +271,89 @@ def test_result(): result = run_op_node([node], ng.ops.result) assert np.allclose(result, node) + + +def test_node_output(): + input_array = np.array([0, 1, 2, 3, 4, 5]) + splits = 3 + expected_shape = len(input_array) // splits + + input_tensor = ng.constant(input_array, dtype=np.int32) + axis = ng.constant(0, dtype=np.int64) + split_node = ng.split(input_tensor, axis, splits) + + split_node_outputs = split_node.outputs() + + assert len(split_node_outputs) == splits + assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2] + assert np.equal( + [output_node.get_element_type() for output_node in split_node_outputs], + input_tensor.get_element_type(), + ).all() + assert np.equal( + [output_node.get_shape() for output_node in split_node_outputs], + Shape([expected_shape]), + ).all() + assert np.equal( + [output_node.get_partial_shape() for output_node in split_node_outputs], + PartialShape([expected_shape]), + ).all() + + output0 = split_node.output(0) + output1 = split_node.output(1) + output2 = split_node.output(2) + + assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2] + + +def test_node_input(): + shape = [2, 2] + parameter_a = ng.parameter(shape, dtype=np.float32, name="A") + parameter_b = ng.parameter(shape, dtype=np.float32, name="B") + + model = parameter_a + parameter_b + + model_inputs = model.inputs() + + assert len(model_inputs) == 2 + assert [input_node.get_index() for input_node in model_inputs] == [0, 1] + assert np.equal( + [input_node.get_element_type() for input_node in model_inputs], + model.get_element_type(), + ).all() + assert np.equal( + [input_node.get_shape() for input_node in model_inputs], Shape(shape) + ).all() + assert np.equal( + [input_node.get_partial_shape() for input_node in model_inputs], + PartialShape(shape), + ).all() + + input0 = model.input(0) + input1 = model.input(1) + + assert [input0.get_index(), input1.get_index()] == [0, 1] + + +def test_node_target_inputs_soruce_output(): + shape = [2, 2] + parameter_a = ng.parameter(shape, dtype=np.float32, name="A") + parameter_b = ng.parameter(shape, dtype=np.float32, name="B") + + model = parameter_a + parameter_b + + out_a = list(parameter_a.output(0).get_target_inputs())[0] + out_b = list(parameter_b.output(0).get_target_inputs())[0] + + assert out_a.get_node().name == model.name + assert out_b.get_node().name == model.name + assert np.equal([out_a.get_shape()], [model.get_output_shape(0)]).all() + assert np.equal([out_b.get_shape()], [model.get_output_shape(0)]).all() + + in_model0 = model.input(0).get_source_output() + in_model1 = model.input(1).get_source_output() + + assert in_model0.get_node().name == parameter_a.name + assert in_model1.get_node().name == parameter_b.name + assert np.equal([in_model0.get_shape()], [model.get_output_shape(0)]).all() + assert np.equal([in_model1.get_shape()], [model.get_output_shape(0)]).all()