Extend nGraph for operation GatherND-5 and implement reference (#2587)
Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -17,6 +17,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import ngraph as ng
|
||||
from ngraph.impl import Type
|
||||
from tests.runtime import get_runtime
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
|
||||
@@ -199,3 +200,18 @@ def test_select():
|
||||
|
||||
result = run_op_node([cond, then_node, else_node], ng.select)
|
||||
assert np.allclose(result, excepted)
|
||||
|
||||
|
||||
def test_gather_nd():
|
||||
indices_type = np.int32
|
||||
data_dtype = np.float32
|
||||
data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data")
|
||||
indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices")
|
||||
batch_dims = 2
|
||||
expected_shape = [20, 30, 40, 50]
|
||||
|
||||
node = ng.gather_nd(data, indices, batch_dims)
|
||||
assert node.get_type_name() == "GatherND"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == expected_shape
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
|
||||
Reference in New Issue
Block a user