Add Input and Output class to Py API (#1284)

This commit is contained in:
Jan Iwaszkiewicz 2020-07-15 15:32:24 +02:00 committed by GitHub
parent 173ce2c907
commit db09547087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 223 additions and 1 deletions

View File

@ -185,6 +185,8 @@ sources = [
"pyngraph/dimension.cpp", "pyngraph/dimension.cpp",
"pyngraph/function.cpp", "pyngraph/function.cpp",
"pyngraph/node.cpp", "pyngraph/node.cpp",
"pyngraph/node_input.cpp",
"pyngraph/node_output.cpp",
"pyngraph/node_factory.cpp", "pyngraph/node_factory.cpp",
"pyngraph/ops/constant.cpp", "pyngraph/ops/constant.cpp",
"pyngraph/ops/get_output_element.cpp", "pyngraph/ops/get_output_element.cpp",

View File

@ -34,6 +34,8 @@ if sys.platform == "win32":
from _pyngraph import Dimension from _pyngraph import Dimension
from _pyngraph import Function from _pyngraph import Function
from _pyngraph import Input
from _pyngraph import Output
from _pyngraph import Node from _pyngraph import Node
from _pyngraph import Type from _pyngraph import Type
from _pyngraph import PartialShape from _pyngraph import PartialShape

View File

@ -77,6 +77,14 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape); node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
node.def("get_type_name", &ngraph::Node::get_type_name); node.def("get_type_name", &ngraph::Node::get_type_name);
node.def("get_unique_name", &ngraph::Node::get_name); node.def("get_unique_name", &ngraph::Node::get_name);
node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input);
node.def("inputs",
(std::vector<ngraph::Input<ngraph::Node>>(ngraph::Node::*)()) & ngraph::Node::inputs);
node.def("output",
(ngraph::Output<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::output);
node.def("outputs",
(std::vector<ngraph::Output<ngraph::Node>>(ngraph::Node::*)()) &
ngraph::Node::outputs);
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name); node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def_property_readonly("shape", &ngraph::Node::get_shape); node.def_property_readonly("shape", &ngraph::Node::get_shape);

View File

@ -0,0 +1,37 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/stl.h>
#include "dict_attribute_visitor.hpp"
#include "ngraph/node_input.hpp"
#include "pyngraph/node_input.hpp"
namespace py = pybind11;
void regclass_pyngraph_Input(py::module m)
{
py::class_<ngraph::Input<ngraph::Node>, std::shared_ptr<ngraph::Input<ngraph::Node>>> input(
m, "Input", py::dynamic_attr());
input.doc() = "ngraph.impl.Input wraps ngraph::Input<Node>";
input.def("get_node", &ngraph::Input<ngraph::Node>::get_node);
input.def("get_index", &ngraph::Input<ngraph::Node>::get_index);
input.def("get_element_type", &ngraph::Input<ngraph::Node>::get_element_type);
input.def("get_shape", &ngraph::Input<ngraph::Node>::get_shape);
input.def("get_partial_shape", &ngraph::Input<ngraph::Node>::get_partial_shape);
input.def("get_source_output", &ngraph::Input<ngraph::Node>::get_source_output);
}

View File

@ -0,0 +1,23 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_Input(py::module m);

View File

@ -0,0 +1,37 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/stl.h>
#include "dict_attribute_visitor.hpp"
#include "ngraph/node_output.hpp"
#include "pyngraph/node_output.hpp"
namespace py = pybind11;
void regclass_pyngraph_Output(py::module m)
{
py::class_<ngraph::Output<ngraph::Node>, std::shared_ptr<ngraph::Output<ngraph::Node>>> output(
m, "Output", py::dynamic_attr());
output.doc() = "ngraph.impl.Output wraps ngraph::Output<Node>";
output.def("get_node", &ngraph::Output<ngraph::Node>::get_node);
output.def("get_index", &ngraph::Output<ngraph::Node>::get_index);
output.def("get_element_type", &ngraph::Output<ngraph::Node>::get_element_type);
output.def("get_shape", &ngraph::Output<ngraph::Node>::get_shape);
output.def("get_partial_shape", &ngraph::Output<ngraph::Node>::get_partial_shape);
output.def("get_target_inputs", &ngraph::Output<ngraph::Node>::get_target_inputs);
}

View File

@ -0,0 +1,23 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_Output(py::module m);

View File

@ -22,6 +22,8 @@
#include "pyngraph/function.hpp" #include "pyngraph/function.hpp"
#include "pyngraph/node.hpp" #include "pyngraph/node.hpp"
#include "pyngraph/node_factory.hpp" #include "pyngraph/node_factory.hpp"
#include "pyngraph/node_input.hpp"
#include "pyngraph/node_output.hpp"
#if defined(NGRAPH_ONNX_IMPORT_ENABLE) #if defined(NGRAPH_ONNX_IMPORT_ENABLE)
#include "pyngraph/onnx_import/onnx_import.hpp" #include "pyngraph/onnx_import/onnx_import.hpp"
#endif #endif
@ -43,6 +45,8 @@ PYBIND11_MODULE(_pyngraph, m)
{ {
m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph"; m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph";
regclass_pyngraph_Node(m); regclass_pyngraph_Node(m);
regclass_pyngraph_Input(m);
regclass_pyngraph_Output(m);
regclass_pyngraph_NodeFactory(m); regclass_pyngraph_NodeFactory(m);
regclass_pyngraph_Dimension(m); // Dimension must be registered before PartialShape regclass_pyngraph_Dimension(m); // Dimension must be registered before PartialShape
regclass_pyngraph_PartialShape(m); regclass_pyngraph_PartialShape(m);

View File

@ -20,7 +20,7 @@ import pytest
import ngraph as ng import ngraph as ng
from ngraph.exceptions import UserInputError from ngraph.exceptions import UserInputError
from ngraph.impl import Function from ngraph.impl import Function, PartialShape, Shape
from tests.runtime import get_runtime from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node from tests.test_ngraph.util import run_op_node
@ -271,3 +271,89 @@ def test_result():
result = run_op_node([node], ng.ops.result) result = run_op_node([node], ng.ops.result)
assert np.allclose(result, node) assert np.allclose(result, node)
def test_node_output():
input_array = np.array([0, 1, 2, 3, 4, 5])
splits = 3
expected_shape = len(input_array) // splits
input_tensor = ng.constant(input_array, dtype=np.int32)
axis = ng.constant(0, dtype=np.int64)
split_node = ng.split(input_tensor, axis, splits)
split_node_outputs = split_node.outputs()
assert len(split_node_outputs) == splits
assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2]
assert np.equal(
[output_node.get_element_type() for output_node in split_node_outputs],
input_tensor.get_element_type(),
).all()
assert np.equal(
[output_node.get_shape() for output_node in split_node_outputs],
Shape([expected_shape]),
).all()
assert np.equal(
[output_node.get_partial_shape() for output_node in split_node_outputs],
PartialShape([expected_shape]),
).all()
output0 = split_node.output(0)
output1 = split_node.output(1)
output2 = split_node.output(2)
assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]
def test_node_input():
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
model = parameter_a + parameter_b
model_inputs = model.inputs()
assert len(model_inputs) == 2
assert [input_node.get_index() for input_node in model_inputs] == [0, 1]
assert np.equal(
[input_node.get_element_type() for input_node in model_inputs],
model.get_element_type(),
).all()
assert np.equal(
[input_node.get_shape() for input_node in model_inputs], Shape(shape)
).all()
assert np.equal(
[input_node.get_partial_shape() for input_node in model_inputs],
PartialShape(shape),
).all()
input0 = model.input(0)
input1 = model.input(1)
assert [input0.get_index(), input1.get_index()] == [0, 1]
def test_node_target_inputs_soruce_output():
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
model = parameter_a + parameter_b
out_a = list(parameter_a.output(0).get_target_inputs())[0]
out_b = list(parameter_b.output(0).get_target_inputs())[0]
assert out_a.get_node().name == model.name
assert out_b.get_node().name == model.name
assert np.equal([out_a.get_shape()], [model.get_output_shape(0)]).all()
assert np.equal([out_b.get_shape()], [model.get_output_shape(0)]).all()
in_model0 = model.input(0).get_source_output()
in_model1 = model.input(1).get_source_output()
assert in_model0.get_node().name == parameter_a.name
assert in_model1.get_node().name == parameter_b.name
assert np.equal([in_model0.get_shape()], [model.get_output_shape(0)]).all()
assert np.equal([in_model1.get_shape()], [model.get_output_shape(0)]).all()