Add Input and Output class to Py API (#1284)

This commit is contained in:
Jan Iwaszkiewicz
2020-07-15 15:32:24 +02:00
committed by GitHub
parent 173ce2c907
commit db09547087
9 changed files with 223 additions and 1 deletions

View File

@@ -20,7 +20,7 @@ import pytest
import ngraph as ng
from ngraph.exceptions import UserInputError
from ngraph.impl import Function
from ngraph.impl import Function, PartialShape, Shape
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
@@ -271,3 +271,89 @@ def test_result():
result = run_op_node([node], ng.ops.result)
assert np.allclose(result, node)
def test_node_output():
input_array = np.array([0, 1, 2, 3, 4, 5])
splits = 3
expected_shape = len(input_array) // splits
input_tensor = ng.constant(input_array, dtype=np.int32)
axis = ng.constant(0, dtype=np.int64)
split_node = ng.split(input_tensor, axis, splits)
split_node_outputs = split_node.outputs()
assert len(split_node_outputs) == splits
assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2]
assert np.equal(
[output_node.get_element_type() for output_node in split_node_outputs],
input_tensor.get_element_type(),
).all()
assert np.equal(
[output_node.get_shape() for output_node in split_node_outputs],
Shape([expected_shape]),
).all()
assert np.equal(
[output_node.get_partial_shape() for output_node in split_node_outputs],
PartialShape([expected_shape]),
).all()
output0 = split_node.output(0)
output1 = split_node.output(1)
output2 = split_node.output(2)
assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]
def test_node_input():
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
model = parameter_a + parameter_b
model_inputs = model.inputs()
assert len(model_inputs) == 2
assert [input_node.get_index() for input_node in model_inputs] == [0, 1]
assert np.equal(
[input_node.get_element_type() for input_node in model_inputs],
model.get_element_type(),
).all()
assert np.equal(
[input_node.get_shape() for input_node in model_inputs], Shape(shape)
).all()
assert np.equal(
[input_node.get_partial_shape() for input_node in model_inputs],
PartialShape(shape),
).all()
input0 = model.input(0)
input1 = model.input(1)
assert [input0.get_index(), input1.get_index()] == [0, 1]
def test_node_target_inputs_soruce_output():
shape = [2, 2]
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
model = parameter_a + parameter_b
out_a = list(parameter_a.output(0).get_target_inputs())[0]
out_b = list(parameter_b.output(0).get_target_inputs())[0]
assert out_a.get_node().name == model.name
assert out_b.get_node().name == model.name
assert np.equal([out_a.get_shape()], [model.get_output_shape(0)]).all()
assert np.equal([out_b.get_shape()], [model.get_output_shape(0)]).all()
in_model0 = model.input(0).get_source_output()
in_model1 = model.input(1).get_source_output()
assert in_model0.get_node().name == parameter_a.name
assert in_model1.get_node().name == parameter_b.name
assert np.equal([in_model0.get_shape()], [model.get_output_shape(0)]).all()
assert np.equal([in_model1.get_shape()], [model.get_output_shape(0)]).all()