Add Input and Output class to Py API (#1284)

This commit is contained in:
Jan Iwaszkiewicz 2020-07-15 15:32:24 +02:00 committed by GitHub
parent 173ce2c907
commit db09547087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 223 additions and 1 deletions

View File

@ -185,6 +185,8 @@ sources = [
"pyngraph/dimension.cpp",
"pyngraph/function.cpp",
"pyngraph/node.cpp",
"pyngraph/node_input.cpp",
"pyngraph/node_output.cpp",
"pyngraph/node_factory.cpp",
"pyngraph/ops/constant.cpp",
"pyngraph/ops/get_output_element.cpp",

View File

@ -34,6 +34,8 @@ if sys.platform == "win32":
from _pyngraph import Dimension
from _pyngraph import Function
from _pyngraph import Input
from _pyngraph import Output
from _pyngraph import Node
from _pyngraph import Type
from _pyngraph import PartialShape

View File

@ -77,6 +77,14 @@ void regclass_pyngraph_Node(py::module m)
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
node.def("get_type_name", &ngraph::Node::get_type_name);
node.def("get_unique_name", &ngraph::Node::get_name);
node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input);
node.def("inputs",
(std::vector<ngraph::Input<ngraph::Node>>(ngraph::Node::*)()) & ngraph::Node::inputs);
node.def("output",
(ngraph::Output<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::output);
node.def("outputs",
(std::vector<ngraph::Output<ngraph::Node>>(ngraph::Node::*)()) &
ngraph::Node::outputs);
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
node.def_property_readonly("shape", &ngraph::Node::get_shape);

View File

@ -0,0 +1,37 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/stl.h>
#include "dict_attribute_visitor.hpp"
#include "ngraph/node_input.hpp"
#include "pyngraph/node_input.hpp"
namespace py = pybind11;
void regclass_pyngraph_Input(py::module m)
{
py::class_<ngraph::Input<ngraph::Node>, std::shared_ptr<ngraph::Input<ngraph::Node>>> input(
m, "Input", py::dynamic_attr());
input.doc() = "ngraph.impl.Input wraps ngraph::Input<Node>";
input.def("get_node", &ngraph::Input<ngraph::Node>::get_node);
input.def("get_index", &ngraph::Input<ngraph::Node>::get_index);
input.def("get_element_type", &ngraph::Input<ngraph::Node>::get_element_type);
input.def("get_shape", &ngraph::Input<ngraph::Node>::get_shape);
input.def("get_partial_shape", &ngraph::Input<ngraph::Node>::get_partial_shape);
input.def("get_source_output", &ngraph::Input<ngraph::Node>::get_source_output);
}

View File

@ -0,0 +1,23 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_Input(py::module m);

View File

@ -0,0 +1,37 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/stl.h>
#include "dict_attribute_visitor.hpp"
#include "ngraph/node_output.hpp"
#include "pyngraph/node_output.hpp"
namespace py = pybind11;
void regclass_pyngraph_Output(py::module m)
{
py::class_<ngraph::Output<ngraph::Node>, std::shared_ptr<ngraph::Output<ngraph::Node>>> output(
m, "Output", py::dynamic_attr());
output.doc() = "ngraph.impl.Output wraps ngraph::Output<Node>";
output.def("get_node", &ngraph::Output<ngraph::Node>::get_node);
output.def("get_index", &ngraph::Output<ngraph::Node>::get_index);
output.def("get_element_type", &ngraph::Output<ngraph::Node>::get_element_type);
output.def("get_shape", &ngraph::Output<ngraph::Node>::get_shape);
output.def("get_partial_shape", &ngraph::Output<ngraph::Node>::get_partial_shape);
output.def("get_target_inputs", &ngraph::Output<ngraph::Node>::get_target_inputs);
}

View File

@ -0,0 +1,23 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_Output(py::module m);

View File

@ -22,6 +22,8 @@
#include "pyngraph/function.hpp"
#include "pyngraph/node.hpp"
#include "pyngraph/node_factory.hpp"
#include "pyngraph/node_input.hpp"
#include "pyngraph/node_output.hpp"
#if defined(NGRAPH_ONNX_IMPORT_ENABLE)
#include "pyngraph/onnx_import/onnx_import.hpp"
#endif
@ -43,6 +45,8 @@ PYBIND11_MODULE(_pyngraph, m)
{
m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph";
regclass_pyngraph_Node(m);
regclass_pyngraph_Input(m);
regclass_pyngraph_Output(m);
regclass_pyngraph_NodeFactory(m);
regclass_pyngraph_Dimension(m); // Dimension must be registered before PartialShape
regclass_pyngraph_PartialShape(m);

View File

@ -20,7 +20,7 @@ import pytest
import ngraph as ng
from ngraph.exceptions import UserInputError
from ngraph.impl import Function
from ngraph.impl import Function, PartialShape, Shape
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
@ -271,3 +271,89 @@ def test_result():
result = run_op_node([node], ng.ops.result)
assert np.allclose(result, node)
def test_node_output():
input_array = np.array([0, 1, 2, 3, 4, 5])
splits = 3
expected_shape = len(input_array) // splits
input_tensor = ng.constant(input_array, dtype=np.int32)
axis = ng.constant(0, dtype=np.int64)
split_node = ng.split(input_tensor, axis, splits)
split_node_outputs = split_node.outputs()
assert len(split_node_outputs) == splits
assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2]
assert np.equal(
[output_node.get_element_type() for output_node in split_node_outputs],
input_tensor.get_element_type(),
).all()
assert np.equal(
[output_node.get_shape() for output_node in split_node_outputs],
Shape([expected_shape]),
).all()
assert np.equal(
[output_node.get_partial_shape() for output_node in split_node_outputs],
PartialShape([expected_shape]),
).all()
output0 = split_node.output(0)
output1 = split_node.output(1)
output2 = split_node.output(2)
assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]
def test_node_input():
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
model = parameter_a + parameter_b
model_inputs = model.inputs()
assert len(model_inputs) == 2
assert [input_node.get_index() for input_node in model_inputs] == [0, 1]
assert np.equal(
[input_node.get_element_type() for input_node in model_inputs],
model.get_element_type(),
).all()
assert np.equal(
[input_node.get_shape() for input_node in model_inputs], Shape(shape)
).all()
assert np.equal(
[input_node.get_partial_shape() for input_node in model_inputs],
PartialShape(shape),
).all()
input0 = model.input(0)
input1 = model.input(1)
assert [input0.get_index(), input1.get_index()] == [0, 1]
def test_node_target_inputs_soruce_output():
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
model = parameter_a + parameter_b
out_a = list(parameter_a.output(0).get_target_inputs())[0]
out_b = list(parameter_b.output(0).get_target_inputs())[0]
assert out_a.get_node().name == model.name
assert out_b.get_node().name == model.name
assert np.equal([out_a.get_shape()], [model.get_output_shape(0)]).all()
assert np.equal([out_b.get_shape()], [model.get_output_shape(0)]).all()
in_model0 = model.input(0).get_source_output()
in_model1 = model.input(1).get_source_output()
assert in_model0.get_node().name == parameter_a.name
assert in_model1.get_node().name == parameter_b.name
assert np.equal([in_model0.get_shape()], [model.get_output_shape(0)]).all()
assert np.equal([in_model1.get_shape()], [model.get_output_shape(0)]).all()