Add Input and Output class to Py API (#1284)
This commit is contained in:
parent
173ce2c907
commit
db09547087
@ -185,6 +185,8 @@ sources = [
|
|||||||
"pyngraph/dimension.cpp",
|
"pyngraph/dimension.cpp",
|
||||||
"pyngraph/function.cpp",
|
"pyngraph/function.cpp",
|
||||||
"pyngraph/node.cpp",
|
"pyngraph/node.cpp",
|
||||||
|
"pyngraph/node_input.cpp",
|
||||||
|
"pyngraph/node_output.cpp",
|
||||||
"pyngraph/node_factory.cpp",
|
"pyngraph/node_factory.cpp",
|
||||||
"pyngraph/ops/constant.cpp",
|
"pyngraph/ops/constant.cpp",
|
||||||
"pyngraph/ops/get_output_element.cpp",
|
"pyngraph/ops/get_output_element.cpp",
|
||||||
|
@ -34,6 +34,8 @@ if sys.platform == "win32":
|
|||||||
|
|
||||||
from _pyngraph import Dimension
|
from _pyngraph import Dimension
|
||||||
from _pyngraph import Function
|
from _pyngraph import Function
|
||||||
|
from _pyngraph import Input
|
||||||
|
from _pyngraph import Output
|
||||||
from _pyngraph import Node
|
from _pyngraph import Node
|
||||||
from _pyngraph import Type
|
from _pyngraph import Type
|
||||||
from _pyngraph import PartialShape
|
from _pyngraph import PartialShape
|
||||||
|
@ -77,6 +77,14 @@ void regclass_pyngraph_Node(py::module m)
|
|||||||
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
|
node.def("get_output_partial_shape", &ngraph::Node::get_output_partial_shape);
|
||||||
node.def("get_type_name", &ngraph::Node::get_type_name);
|
node.def("get_type_name", &ngraph::Node::get_type_name);
|
||||||
node.def("get_unique_name", &ngraph::Node::get_name);
|
node.def("get_unique_name", &ngraph::Node::get_name);
|
||||||
|
node.def("input", (ngraph::Input<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::input);
|
||||||
|
node.def("inputs",
|
||||||
|
(std::vector<ngraph::Input<ngraph::Node>>(ngraph::Node::*)()) & ngraph::Node::inputs);
|
||||||
|
node.def("output",
|
||||||
|
(ngraph::Output<ngraph::Node>(ngraph::Node::*)(size_t)) & ngraph::Node::output);
|
||||||
|
node.def("outputs",
|
||||||
|
(std::vector<ngraph::Output<ngraph::Node>>(ngraph::Node::*)()) &
|
||||||
|
ngraph::Node::outputs);
|
||||||
|
|
||||||
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
node.def_property("name", &ngraph::Node::get_friendly_name, &ngraph::Node::set_friendly_name);
|
||||||
node.def_property_readonly("shape", &ngraph::Node::get_shape);
|
node.def_property_readonly("shape", &ngraph::Node::get_shape);
|
||||||
|
37
ngraph/python/src/pyngraph/node_input.cpp
Normal file
37
ngraph/python/src/pyngraph/node_input.cpp
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include "dict_attribute_visitor.hpp"
|
||||||
|
#include "ngraph/node_input.hpp"
|
||||||
|
#include "pyngraph/node_input.hpp"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
void regclass_pyngraph_Input(py::module m)
|
||||||
|
{
|
||||||
|
py::class_<ngraph::Input<ngraph::Node>, std::shared_ptr<ngraph::Input<ngraph::Node>>> input(
|
||||||
|
m, "Input", py::dynamic_attr());
|
||||||
|
input.doc() = "ngraph.impl.Input wraps ngraph::Input<Node>";
|
||||||
|
|
||||||
|
input.def("get_node", &ngraph::Input<ngraph::Node>::get_node);
|
||||||
|
input.def("get_index", &ngraph::Input<ngraph::Node>::get_index);
|
||||||
|
input.def("get_element_type", &ngraph::Input<ngraph::Node>::get_element_type);
|
||||||
|
input.def("get_shape", &ngraph::Input<ngraph::Node>::get_shape);
|
||||||
|
input.def("get_partial_shape", &ngraph::Input<ngraph::Node>::get_partial_shape);
|
||||||
|
input.def("get_source_output", &ngraph::Input<ngraph::Node>::get_source_output);
|
||||||
|
}
|
23
ngraph/python/src/pyngraph/node_input.hpp
Normal file
23
ngraph/python/src/pyngraph/node_input.hpp
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
void regclass_pyngraph_Input(py::module m);
|
37
ngraph/python/src/pyngraph/node_output.cpp
Normal file
37
ngraph/python/src/pyngraph/node_output.cpp
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include "dict_attribute_visitor.hpp"
|
||||||
|
#include "ngraph/node_output.hpp"
|
||||||
|
#include "pyngraph/node_output.hpp"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
void regclass_pyngraph_Output(py::module m)
|
||||||
|
{
|
||||||
|
py::class_<ngraph::Output<ngraph::Node>, std::shared_ptr<ngraph::Output<ngraph::Node>>> output(
|
||||||
|
m, "Output", py::dynamic_attr());
|
||||||
|
output.doc() = "ngraph.impl.Output wraps ngraph::Output<Node>";
|
||||||
|
|
||||||
|
output.def("get_node", &ngraph::Output<ngraph::Node>::get_node);
|
||||||
|
output.def("get_index", &ngraph::Output<ngraph::Node>::get_index);
|
||||||
|
output.def("get_element_type", &ngraph::Output<ngraph::Node>::get_element_type);
|
||||||
|
output.def("get_shape", &ngraph::Output<ngraph::Node>::get_shape);
|
||||||
|
output.def("get_partial_shape", &ngraph::Output<ngraph::Node>::get_partial_shape);
|
||||||
|
output.def("get_target_inputs", &ngraph::Output<ngraph::Node>::get_target_inputs);
|
||||||
|
}
|
23
ngraph/python/src/pyngraph/node_output.hpp
Normal file
23
ngraph/python/src/pyngraph/node_output.hpp
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2020 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
void regclass_pyngraph_Output(py::module m);
|
@ -22,6 +22,8 @@
|
|||||||
#include "pyngraph/function.hpp"
|
#include "pyngraph/function.hpp"
|
||||||
#include "pyngraph/node.hpp"
|
#include "pyngraph/node.hpp"
|
||||||
#include "pyngraph/node_factory.hpp"
|
#include "pyngraph/node_factory.hpp"
|
||||||
|
#include "pyngraph/node_input.hpp"
|
||||||
|
#include "pyngraph/node_output.hpp"
|
||||||
#if defined(NGRAPH_ONNX_IMPORT_ENABLE)
|
#if defined(NGRAPH_ONNX_IMPORT_ENABLE)
|
||||||
#include "pyngraph/onnx_import/onnx_import.hpp"
|
#include "pyngraph/onnx_import/onnx_import.hpp"
|
||||||
#endif
|
#endif
|
||||||
@ -43,6 +45,8 @@ PYBIND11_MODULE(_pyngraph, m)
|
|||||||
{
|
{
|
||||||
m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph";
|
m.doc() = "Package ngraph.impl that wraps nGraph's namespace ngraph";
|
||||||
regclass_pyngraph_Node(m);
|
regclass_pyngraph_Node(m);
|
||||||
|
regclass_pyngraph_Input(m);
|
||||||
|
regclass_pyngraph_Output(m);
|
||||||
regclass_pyngraph_NodeFactory(m);
|
regclass_pyngraph_NodeFactory(m);
|
||||||
regclass_pyngraph_Dimension(m); // Dimension must be registered before PartialShape
|
regclass_pyngraph_Dimension(m); // Dimension must be registered before PartialShape
|
||||||
regclass_pyngraph_PartialShape(m);
|
regclass_pyngraph_PartialShape(m);
|
||||||
|
@ -20,7 +20,7 @@ import pytest
|
|||||||
|
|
||||||
import ngraph as ng
|
import ngraph as ng
|
||||||
from ngraph.exceptions import UserInputError
|
from ngraph.exceptions import UserInputError
|
||||||
from ngraph.impl import Function
|
from ngraph.impl import Function, PartialShape, Shape
|
||||||
from tests.runtime import get_runtime
|
from tests.runtime import get_runtime
|
||||||
from tests.test_ngraph.util import run_op_node
|
from tests.test_ngraph.util import run_op_node
|
||||||
|
|
||||||
@ -271,3 +271,89 @@ def test_result():
|
|||||||
|
|
||||||
result = run_op_node([node], ng.ops.result)
|
result = run_op_node([node], ng.ops.result)
|
||||||
assert np.allclose(result, node)
|
assert np.allclose(result, node)
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_output():
|
||||||
|
input_array = np.array([0, 1, 2, 3, 4, 5])
|
||||||
|
splits = 3
|
||||||
|
expected_shape = len(input_array) // splits
|
||||||
|
|
||||||
|
input_tensor = ng.constant(input_array, dtype=np.int32)
|
||||||
|
axis = ng.constant(0, dtype=np.int64)
|
||||||
|
split_node = ng.split(input_tensor, axis, splits)
|
||||||
|
|
||||||
|
split_node_outputs = split_node.outputs()
|
||||||
|
|
||||||
|
assert len(split_node_outputs) == splits
|
||||||
|
assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2]
|
||||||
|
assert np.equal(
|
||||||
|
[output_node.get_element_type() for output_node in split_node_outputs],
|
||||||
|
input_tensor.get_element_type(),
|
||||||
|
).all()
|
||||||
|
assert np.equal(
|
||||||
|
[output_node.get_shape() for output_node in split_node_outputs],
|
||||||
|
Shape([expected_shape]),
|
||||||
|
).all()
|
||||||
|
assert np.equal(
|
||||||
|
[output_node.get_partial_shape() for output_node in split_node_outputs],
|
||||||
|
PartialShape([expected_shape]),
|
||||||
|
).all()
|
||||||
|
|
||||||
|
output0 = split_node.output(0)
|
||||||
|
output1 = split_node.output(1)
|
||||||
|
output2 = split_node.output(2)
|
||||||
|
|
||||||
|
assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_input():
|
||||||
|
shape = [2, 2]
|
||||||
|
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
|
||||||
|
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
|
||||||
|
|
||||||
|
model = parameter_a + parameter_b
|
||||||
|
|
||||||
|
model_inputs = model.inputs()
|
||||||
|
|
||||||
|
assert len(model_inputs) == 2
|
||||||
|
assert [input_node.get_index() for input_node in model_inputs] == [0, 1]
|
||||||
|
assert np.equal(
|
||||||
|
[input_node.get_element_type() for input_node in model_inputs],
|
||||||
|
model.get_element_type(),
|
||||||
|
).all()
|
||||||
|
assert np.equal(
|
||||||
|
[input_node.get_shape() for input_node in model_inputs], Shape(shape)
|
||||||
|
).all()
|
||||||
|
assert np.equal(
|
||||||
|
[input_node.get_partial_shape() for input_node in model_inputs],
|
||||||
|
PartialShape(shape),
|
||||||
|
).all()
|
||||||
|
|
||||||
|
input0 = model.input(0)
|
||||||
|
input1 = model.input(1)
|
||||||
|
|
||||||
|
assert [input0.get_index(), input1.get_index()] == [0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_target_inputs_soruce_output():
|
||||||
|
shape = [2, 2]
|
||||||
|
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
|
||||||
|
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
|
||||||
|
|
||||||
|
model = parameter_a + parameter_b
|
||||||
|
|
||||||
|
out_a = list(parameter_a.output(0).get_target_inputs())[0]
|
||||||
|
out_b = list(parameter_b.output(0).get_target_inputs())[0]
|
||||||
|
|
||||||
|
assert out_a.get_node().name == model.name
|
||||||
|
assert out_b.get_node().name == model.name
|
||||||
|
assert np.equal([out_a.get_shape()], [model.get_output_shape(0)]).all()
|
||||||
|
assert np.equal([out_b.get_shape()], [model.get_output_shape(0)]).all()
|
||||||
|
|
||||||
|
in_model0 = model.input(0).get_source_output()
|
||||||
|
in_model1 = model.input(1).get_source_output()
|
||||||
|
|
||||||
|
assert in_model0.get_node().name == parameter_a.name
|
||||||
|
assert in_model1.get_node().name == parameter_b.name
|
||||||
|
assert np.equal([in_model0.get_shape()], [model.get_output_shape(0)]).all()
|
||||||
|
assert np.equal([in_model1.get_shape()], [model.get_output_shape(0)]).all()
|
||||||
|
Loading…
Reference in New Issue
Block a user