Extend nGraph for operation GatherND-5 and implement reference (#2587)

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2020-10-14 12:20:22 +03:00
committed by GitHub
parent 6d72110365
commit 9956639531
17 changed files with 1245 additions and 307 deletions

View File

@@ -17,6 +17,7 @@ import numpy as np
import pytest
import ngraph as ng
from ngraph.impl import Type
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
@@ -199,3 +200,18 @@ def test_select():
result = run_op_node([cond, then_node, else_node], ng.select)
assert np.allclose(result, excepted)
def test_gather_nd():
indices_type = np.int32
data_dtype = np.float32
data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data")
indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices")
batch_dims = 2
expected_shape = [20, 30, 40, 50]
node = ng.gather_nd(data, indices, batch_dims)
assert node.get_type_name() == "GatherND"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == expected_shape
assert node.get_output_element_type(0) == Type.f32