Add Input and Output class to Py API (#1284)
This commit is contained in:
@@ -20,7 +20,7 @@ import pytest
|
||||
|
||||
import ngraph as ng
|
||||
from ngraph.exceptions import UserInputError
|
||||
from ngraph.impl import Function
|
||||
from ngraph.impl import Function, PartialShape, Shape
|
||||
from tests.runtime import get_runtime
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
|
||||
@@ -271,3 +271,89 @@ def test_result():
|
||||
|
||||
result = run_op_node([node], ng.ops.result)
|
||||
assert np.allclose(result, node)
|
||||
|
||||
|
||||
def test_node_output():
|
||||
input_array = np.array([0, 1, 2, 3, 4, 5])
|
||||
splits = 3
|
||||
expected_shape = len(input_array) // splits
|
||||
|
||||
input_tensor = ng.constant(input_array, dtype=np.int32)
|
||||
axis = ng.constant(0, dtype=np.int64)
|
||||
split_node = ng.split(input_tensor, axis, splits)
|
||||
|
||||
split_node_outputs = split_node.outputs()
|
||||
|
||||
assert len(split_node_outputs) == splits
|
||||
assert [output_node.get_index() for output_node in split_node_outputs] == [0, 1, 2]
|
||||
assert np.equal(
|
||||
[output_node.get_element_type() for output_node in split_node_outputs],
|
||||
input_tensor.get_element_type(),
|
||||
).all()
|
||||
assert np.equal(
|
||||
[output_node.get_shape() for output_node in split_node_outputs],
|
||||
Shape([expected_shape]),
|
||||
).all()
|
||||
assert np.equal(
|
||||
[output_node.get_partial_shape() for output_node in split_node_outputs],
|
||||
PartialShape([expected_shape]),
|
||||
).all()
|
||||
|
||||
output0 = split_node.output(0)
|
||||
output1 = split_node.output(1)
|
||||
output2 = split_node.output(2)
|
||||
|
||||
assert [output0.get_index(), output1.get_index(), output2.get_index()] == [0, 1, 2]
|
||||
|
||||
|
||||
def test_node_input():
|
||||
shape = [2, 2]
|
||||
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
|
||||
|
||||
model = parameter_a + parameter_b
|
||||
|
||||
model_inputs = model.inputs()
|
||||
|
||||
assert len(model_inputs) == 2
|
||||
assert [input_node.get_index() for input_node in model_inputs] == [0, 1]
|
||||
assert np.equal(
|
||||
[input_node.get_element_type() for input_node in model_inputs],
|
||||
model.get_element_type(),
|
||||
).all()
|
||||
assert np.equal(
|
||||
[input_node.get_shape() for input_node in model_inputs], Shape(shape)
|
||||
).all()
|
||||
assert np.equal(
|
||||
[input_node.get_partial_shape() for input_node in model_inputs],
|
||||
PartialShape(shape),
|
||||
).all()
|
||||
|
||||
input0 = model.input(0)
|
||||
input1 = model.input(1)
|
||||
|
||||
assert [input0.get_index(), input1.get_index()] == [0, 1]
|
||||
|
||||
|
||||
def test_node_target_inputs_soruce_output():
|
||||
shape = [2, 2]
|
||||
parameter_a = ng.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ng.parameter(shape, dtype=np.float32, name="B")
|
||||
|
||||
model = parameter_a + parameter_b
|
||||
|
||||
out_a = list(parameter_a.output(0).get_target_inputs())[0]
|
||||
out_b = list(parameter_b.output(0).get_target_inputs())[0]
|
||||
|
||||
assert out_a.get_node().name == model.name
|
||||
assert out_b.get_node().name == model.name
|
||||
assert np.equal([out_a.get_shape()], [model.get_output_shape(0)]).all()
|
||||
assert np.equal([out_b.get_shape()], [model.get_output_shape(0)]).all()
|
||||
|
||||
in_model0 = model.input(0).get_source_output()
|
||||
in_model1 = model.input(1).get_source_output()
|
||||
|
||||
assert in_model0.get_node().name == parameter_a.name
|
||||
assert in_model1.get_node().name == parameter_b.name
|
||||
assert np.equal([in_model0.get_shape()], [model.get_output_shape(0)]).all()
|
||||
assert np.equal([in_model1.get_shape()], [model.get_output_shape(0)]).all()
|
||||
|
||||
Reference in New Issue
Block a user