[MO|nGraph]GatherND_8 (#7743)
* Add GatherND_8 operation * Update shape infer function and tests * Initial commit for nGraph GatherND_8 operation * Add GatherNDBase class implementation * Fix base class errors * Add missrd header * Update base class * Update GatherND_8 implementation * Fix codestyle * Fix wrong rank * Implement tests for gatherND_8 shape inference function * fix codestyle * Add limitation to doc * Siplyfy check in shape inference * Add more test cases * Update shape inference function * Add more test cases to cover all case with dynamic input shapes * Update shape inference function * Refactor tests * Add visitor tests for gatherND_8 operation * Correct comment * Add additional check is shape inference function * Update shape inference implementation for gathernd operartion * Fix codestyle * Remove restriction for data is fully defined * Update shape inference functon * Fix missed check for nonetype * Remove redundant checks for batch_dims * Fix codestyle
This commit is contained in:
committed by
GitHub
parent
76994c6ec9
commit
c8e1c8e3eb
@@ -525,7 +525,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
|
||||
| GRU | |
|
||||
| Gather | |
|
||||
| GatherElements | Doesn't work with negative indices |
|
||||
| GatherND | |
|
||||
| GatherND | Doesn't work with negative indices |
|
||||
| GatherTree | |
|
||||
| Gemm | |
|
||||
| GlobalAveragePool | |
|
||||
|
||||
@@ -16,7 +16,7 @@ class GatherND(Op):
|
||||
mandatory_props = {
|
||||
'type': self.op,
|
||||
'op': self.op,
|
||||
'version': 'opset5',
|
||||
'version': 'opset8',
|
||||
'infer': self.infer,
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
@@ -56,41 +56,55 @@ class GatherND(Op):
|
||||
assert len(indices_shape) > 0, "Indices must not be a scalar"
|
||||
assert (batch_dims + indices_shape[-1]) <= len(data_shape), \
|
||||
"Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions"
|
||||
assert node['version'] in ['opset5', 'opset8'], 'Unsupported version of GatherND operation: {}, operation ' \
|
||||
'name : {}'.format(node['version'], node.soft_get('name'))
|
||||
|
||||
# compute output shape
|
||||
batch = []
|
||||
if batch_dims > 0:
|
||||
if is_fully_defined(data_shape[:batch_dims]):
|
||||
batch = [np.prod(data_shape[:batch_dims]).tolist()]
|
||||
else:
|
||||
batch = [dynamic_dimension_value]
|
||||
else:
|
||||
batch = []
|
||||
if node['version'] == 'opset5': # Support old version of gatherND shape inference
|
||||
if is_fully_defined(data_shape[:batch_dims]):
|
||||
batch = [np.prod(data_shape[:batch_dims]).tolist()]
|
||||
else:
|
||||
batch = [dynamic_dimension_value]
|
||||
elif node['version'] == 'opset8':
|
||||
for dim in range(batch_dims):
|
||||
assert compatible_dims(indices_shape[dim], data_shape[dim]),\
|
||||
"Batch dimensions in data.shape and indices.shape must be compatible"
|
||||
if is_fully_defined(indices_shape[:batch_dims]):
|
||||
batch = indices_shape[:batch_dims].tolist()
|
||||
elif is_fully_defined(data_shape[:batch_dims]):
|
||||
batch = data_shape[:batch_dims].tolist()
|
||||
else:
|
||||
for ind in range(batch_dims):
|
||||
if indices_shape[ind] != dynamic_dimension_value:
|
||||
batch.append(indices_shape[ind])
|
||||
elif data_shape[ind] != dynamic_dimension_value:
|
||||
batch.append(data_shape[ind])
|
||||
else:
|
||||
batch.append(dynamic_dimension_value)
|
||||
|
||||
slice_shape = list(data_shape[(batch_dims + indices_shape[-1]):])
|
||||
output_shape = batch + list(indices_shape[batch_dims:-1]) + slice_shape
|
||||
|
||||
output_shape = batch + list(indices_shape)[batch_dims:-1] + slice_shape
|
||||
node.out_port(0).data.set_shape(output_shape)
|
||||
|
||||
# compute output value if all input values are defined
|
||||
if is_fully_defined(indices_value) and is_fully_defined(data_value):
|
||||
output_value = np.zeros(output_shape, dtype=data_value.dtype)
|
||||
if batch_dims == 0:
|
||||
output_indices_range = int64_array(indices_shape[:-1])
|
||||
for output_index in np.ndindex(tuple(output_indices_range)):
|
||||
indices_tuple = indices_value[output_index]
|
||||
output_value[output_index] = data_value[tuple(indices_tuple.T)]
|
||||
else:
|
||||
batch_dims_range = int64_array(indices_shape[:batch_dims])
|
||||
for batch_indices in np.ndindex(tuple(batch_dims_range)):
|
||||
# compute batch index in output tensor
|
||||
batch_ind = 0
|
||||
num_elements = 1
|
||||
for ind in reversed(range(len(batch_dims_range))):
|
||||
batch_ind += batch_indices[ind] * num_elements
|
||||
num_elements *= batch_dims_range[ind]
|
||||
output_indices_range = int64_array(indices_shape[batch_dims:-1])
|
||||
for output_index in np.ndindex(tuple(output_indices_range)):
|
||||
tmp_ind = batch_indices + output_index
|
||||
indices_tuple = tuple(indices_value[tmp_ind].T)
|
||||
full_input_ind = batch_indices + indices_tuple
|
||||
full_output_ind = tuple(np.array([batch_ind]).T) + output_index
|
||||
output_value[full_output_ind] = data_value[full_input_ind]
|
||||
# compute output value if all input indices are defined
|
||||
if is_fully_defined(indices_value) and data_value is not None:
|
||||
batch_dims_size = 1
|
||||
|
||||
for i in range(batch_dims):
|
||||
batch_dims_size *= indices_shape[i]
|
||||
|
||||
output_data = []
|
||||
|
||||
reshaped_indices = indices_value.reshape(batch_dims_size, -1, indices_shape[-1])
|
||||
|
||||
reshaped_data = data_value.reshape((batch_dims_size,) + tuple((data_shape[batch_dims:])))
|
||||
|
||||
for batch_dim in range(reshaped_indices.shape[0]):
|
||||
for outer_dim in range(reshaped_indices.shape[1]):
|
||||
gather_index = tuple(reshaped_indices[batch_dim][outer_dim])
|
||||
output_data.append(reshaped_data[(batch_dim,) + gather_index])
|
||||
output_value = np.asarray(output_data, dtype=data_value.dtype).reshape(output_shape)
|
||||
node.out_port(0).data.set_value(output_value)
|
||||
|
||||
@@ -14,7 +14,7 @@ nodes_attributes = {'data': {'kind': 'op'},
|
||||
'data_data': {'shape': None, 'value': None, 'kind': 'data'},
|
||||
'indices': {'kind': 'op'},
|
||||
'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
|
||||
'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0},
|
||||
'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0, 'version': 'opset8'},
|
||||
'output': {'shape': None, 'value': None, 'kind': 'data'}}
|
||||
|
||||
# graph 1
|
||||
@@ -25,17 +25,21 @@ edges = [('data', 'data_data', {'in': 0}),
|
||||
('gathernd_node', 'output', {'out': 0})]
|
||||
|
||||
# test data for partial infer: gather elements
|
||||
inputs1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
|
||||
inputs = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
|
||||
'indices_data': {'shape': int64_array([3, 2]), 'value': None}}
|
||||
|
||||
# test data for partial infer: gather slices
|
||||
inputs2 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None},
|
||||
inputs1 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None},
|
||||
'indices_data': {'shape': int64_array([3, 2]), 'value': None}}
|
||||
|
||||
# test data for partial infer: gather slices and batch_dims=2
|
||||
inputs3 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None},
|
||||
inputs2 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None},
|
||||
'indices_data': {'shape': int64_array([10, 40, 3, 5, 1]), 'value': None}}
|
||||
|
||||
# test data for partial infer: gather slices and batch_dims=3 and indices.shape[-1]=len(data.shape)-batch_dims
|
||||
inputs3 = {'data_data': {'shape': int64_array([1, 64, 64, 320]), 'value': None},
|
||||
'indices_data': {'shape': int64_array([1, 64, 64, 1, 1]), 'value': None}}
|
||||
|
||||
# test data for constant folding: gather elements, batch_dims = 0
|
||||
inputs4 = {'data_data': {'shape': int64_array([2, 2]), 'value': int64_array([[1, 2],
|
||||
[3, 4]])},
|
||||
@@ -110,6 +114,14 @@ output8 = int64_array([[3, 8, 6],
|
||||
inputs9 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None},
|
||||
'indices_data': {'shape': shape_array([dynamic_dimension_value, 40, 3, 5, 1]), 'value': None}}
|
||||
|
||||
# test data for partial infer: gather slices and batch_dims=2
|
||||
inputs10 = {'data_data': {'shape': shape_array([40, dynamic_dimension_value, 4, 9]), 'value': None},
|
||||
'indices_data': {'shape': shape_array([40, dynamic_dimension_value, 3, 5, 1]), 'value': None}}
|
||||
|
||||
# test data for partial infer: gather slices and batch_dims=2
|
||||
inputs11 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None},
|
||||
'indices_data': {'shape': shape_array([40, dynamic_dimension_value, 3, 5, 1]), 'value': None}}
|
||||
|
||||
# invalid test case with incorrect rank for indices
|
||||
inputs_inv1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
|
||||
'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}}
|
||||
@@ -123,12 +135,13 @@ inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value':
|
||||
'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
|
||||
|
||||
|
||||
class TestGatherNDUpdate(unittest.TestCase):
|
||||
class TestGatherND_5(unittest.TestCase):
|
||||
def setUp(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 0
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset5'
|
||||
|
||||
def test_partial_infer_gather_element(self):
|
||||
graph = build_graph(nodes_attributes, edges, inputs1)
|
||||
graph = build_graph(nodes_attributes, edges, inputs)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
@@ -142,7 +155,7 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice(self):
|
||||
graph = build_graph(nodes_attributes, edges, inputs2)
|
||||
graph = build_graph(nodes_attributes, edges, inputs1)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
@@ -157,7 +170,7 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
graph = build_graph(nodes_attributes, edges, inputs3)
|
||||
graph = build_graph(nodes_attributes, edges, inputs2)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
@@ -170,7 +183,22 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic(self):
|
||||
def test_partial_infer_gather_slice_batch_dims3(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 3
|
||||
graph = build_graph(nodes_attributes, edges, inputs3)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = int64_array([4096, 1])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic1(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
graph = build_graph(nodes_attributes, edges, inputs9)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
@@ -185,6 +213,36 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic2(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
graph = build_graph(nodes_attributes, edges, inputs10)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic3(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
graph = build_graph(nodes_attributes, edges, inputs11)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_infer4(self):
|
||||
graph = build_graph(nodes_attributes, edges, inputs4)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
@@ -205,7 +263,7 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
self.assertTrue(np.array_equal(output5, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
|
||||
'values do not match expected: {} and given: {}'.format(output5, res_output_value))
|
||||
|
||||
def test_infer6(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 1
|
||||
@@ -217,7 +275,7 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
self.assertTrue(np.array_equal(output6, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
|
||||
'values do not match expected: {} and given: {}'.format(output6, res_output_value))
|
||||
|
||||
def test_infer7(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
@@ -228,8 +286,9 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
# get the result
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
self.assertTrue(np.array_equal(output7, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
|
||||
output = output7.reshape([6, 1])
|
||||
self.assertTrue(np.array_equal(output, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output, res_output_value))
|
||||
|
||||
def test_infer8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
@@ -241,7 +300,32 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
self.assertTrue(np.array_equal(output8, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
|
||||
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
|
||||
|
||||
def test_infer9(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
graph = build_graph(nodes_attributes, edges, inputs8)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# get the result
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
self.assertTrue(np.array_equal(output8, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
|
||||
|
||||
def test_infer9_opset_5(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
graph = build_graph(nodes_attributes, edges, inputs8)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# get the result
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
output = output8.reshape([6, 3])
|
||||
self.assertTrue(np.array_equal(output, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output, res_output_value))
|
||||
|
||||
def test_infer_invalid1(self):
|
||||
graph = build_graph(nodes_attributes, edges, inputs_inv1)
|
||||
@@ -259,3 +343,114 @@ class TestGatherNDUpdate(unittest.TestCase):
|
||||
graph = build_graph(nodes_attributes, edges, inputs_inv3)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
self.assertRaises(AssertionError, GatherND.infer, gathernd_node)
|
||||
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs2)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = int64_array([10, 40, 3, 5, 9])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims3_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 3
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs3)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = int64_array([1, 64, 64, 1])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic1_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs9)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = shape_array([dynamic_dimension_value, 40, 3, 5, 9])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic2_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs10)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = shape_array([40, dynamic_dimension_value, 3, 5, 9])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_partial_infer_gather_slice_batch_dims2_dynamic3_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs11)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# prepare reference results
|
||||
ref_output_shape = shape_array([40, 40, 3, 5, 9])
|
||||
|
||||
# get the result
|
||||
res_output_shape = graph.node['output']['shape']
|
||||
|
||||
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
|
||||
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
|
||||
|
||||
def test_infer7_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs7)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# get the result
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
output = output7.reshape([2, 3, 1])
|
||||
|
||||
self.assertTrue(np.array_equal(output, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output, res_output_value))
|
||||
|
||||
def test_infer8_opset8(self):
|
||||
nodes_attributes['gathernd_node']['batch_dims'] = 2
|
||||
nodes_attributes['gathernd_node']['version'] = 'opset8'
|
||||
graph = build_graph(nodes_attributes, edges, inputs8)
|
||||
gathernd_node = Node(graph, 'gathernd_node')
|
||||
GatherND.infer(gathernd_node)
|
||||
|
||||
# get the result
|
||||
res_output_value = graph.node['output']['value']
|
||||
|
||||
output = output8.reshape([2, 3, 3])
|
||||
|
||||
self.assertTrue(np.array_equal(output, res_output_value),
|
||||
'values do not match expected: {} and given: {}'.format(output, res_output_value))
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/gather_nd_base.hpp"
|
||||
#include "openvino/op/gather_nd.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
@@ -12,5 +12,8 @@ namespace op {
|
||||
namespace v5 {
|
||||
using ov::op::v5::GatherND;
|
||||
} // namespace v5
|
||||
namespace v8 {
|
||||
using ov::op::v8::GatherND;
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
||||
16
ngraph/core/include/ngraph/op/util/gather_nd_base.hpp
Normal file
16
ngraph/core/include/ngraph/op/util/gather_nd_base.hpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/gather_nd_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
using ov::op::util::GatherNDBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
@@ -4,16 +4,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/op/op.hpp"
|
||||
#include "openvino/op/util/gather_nd_base.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v5 {
|
||||
/// \brief GatherND operation
|
||||
///
|
||||
class OPENVINO_API GatherND : public Op {
|
||||
class OPENVINO_API GatherND : public op::util::GatherNDBase {
|
||||
public:
|
||||
OPENVINO_OP("GatherND", "opset5", op::Op, 5);
|
||||
OPENVINO_OP("GatherND", "opset5", op::util::GatherNDBase, 5);
|
||||
BWDCMP_RTTI_DECLARATION;
|
||||
GatherND() = default;
|
||||
|
||||
@@ -28,14 +27,30 @@ public:
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
size_t get_batch_dims() const {
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t m_batch_dims;
|
||||
};
|
||||
} // namespace v5
|
||||
|
||||
namespace v8 {
|
||||
/// \brief GatherND operation
|
||||
///
|
||||
class OPENVINO_API GatherND : public op::util::GatherNDBase {
|
||||
public:
|
||||
OPENVINO_OP("GatherND", "opset8", op::util::GatherNDBase);
|
||||
BWDCMP_RTTI_DECLARATION;
|
||||
GatherND() = default;
|
||||
|
||||
/// \brief Constructs a GatherND operation.
|
||||
///
|
||||
/// \param data Node producing data that are gathered
|
||||
/// \param indices Node producing indices by which the operation gathers elements
|
||||
/// or slices from data
|
||||
/// \param batch_dims Specifies a number of batch dimensions
|
||||
GatherND(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims = 0);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
||||
38
ngraph/core/include/openvino/op/util/gather_nd_base.hpp
Normal file
38
ngraph/core/include/openvino/op/util/gather_nd_base.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/op/op.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief GatherNDBase basic class for GatherND v5 and v8
|
||||
class OPENVINO_API GatherNDBase : public Op {
|
||||
public:
|
||||
OPENVINO_OP("GatherNDBase", "util");
|
||||
BWDCMP_RTTI_DECLARATION;
|
||||
GatherNDBase() = default;
|
||||
|
||||
/// \brief Constructs a GatherND operation.
|
||||
///
|
||||
/// \param data Node producing data that are gathered
|
||||
/// \param indices Node producing indices by which the operation gathers elements
|
||||
/// or slices from data
|
||||
/// \param batch_dims Specifies a leading number of dimensions representing the batches
|
||||
GatherNDBase(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims = 0);
|
||||
|
||||
size_t get_batch_dims() const {
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
void validate_inputs_and_infer_shape();
|
||||
|
||||
protected:
|
||||
size_t m_batch_dims = 0;
|
||||
};
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
@@ -144,7 +144,6 @@ _OPENVINO_OP_REG(SoftPlus, ov::op::v4)
|
||||
_OPENVINO_OP_REG(Swish, ov::op::v4)
|
||||
|
||||
// New operations added in opset5
|
||||
_OPENVINO_OP_REG(GatherND, ov::op::v5)
|
||||
_OPENVINO_OP_REG(GRUSequence, ov::op::v5)
|
||||
_OPENVINO_OP_REG(HSigmoid, ov::op::v5)
|
||||
_OPENVINO_OP_REG(LogSoftmax, ov::op::v5)
|
||||
@@ -175,6 +174,7 @@ _OPENVINO_OP_REG(Roll, ov::op::v7)
|
||||
|
||||
// New operations added in opset8
|
||||
_OPENVINO_OP_REG(Gather, ov::op::v8)
|
||||
_OPENVINO_OP_REG(GatherND, ov::op::v8)
|
||||
_OPENVINO_OP_REG(AdaptiveAvgPool, ov::op::v8)
|
||||
_OPENVINO_OP_REG(AdaptiveMaxPool, ov::op::v8)
|
||||
_OPENVINO_OP_REG(DeformableConvolution, ov::op::v8)
|
||||
|
||||
@@ -15,99 +15,43 @@ using namespace ngraph;
|
||||
BWDCMP_RTTI_DEFINITION(op::v5::GatherND);
|
||||
|
||||
op::v5::GatherND::GatherND(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims)
|
||||
: Op({data, indices}),
|
||||
m_batch_dims(batch_dims) {
|
||||
: GatherNDBase(data, indices, batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v5::GatherND::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v5_GatherND_validate_and_infer_types);
|
||||
// check types of input tensors
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
const auto& indices_type = get_input_element_type(1);
|
||||
validate_inputs_and_infer_shape();
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_type.is_integral_number(),
|
||||
"The indices type is expected to be an integer type. Got: ",
|
||||
indices_type);
|
||||
// If we have m_batch_dims > 1 we need to fuse batch dimensions of output
|
||||
if (m_batch_dims > 1) {
|
||||
const auto& output_pshape = get_output_partial_shape(0);
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
|
||||
// check ranks of input tensors
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& indices_pshape = get_input_partial_shape(1);
|
||||
|
||||
if (data_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_pshape.rank().get_length() > static_cast<int64_t>(m_batch_dims),
|
||||
"Number of batch dimensions must not exceed a rank of data.");
|
||||
}
|
||||
|
||||
if (indices_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_pshape.rank().get_length() > static_cast<int64_t>(m_batch_dims),
|
||||
"Number of batch dimensions must not exceed a rank of indices.");
|
||||
}
|
||||
|
||||
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) {
|
||||
// check that batch dimensions of data and indices are the same
|
||||
for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) {
|
||||
if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(),
|
||||
"Batch dimensions of data and indices must be the same.");
|
||||
}
|
||||
}
|
||||
|
||||
if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) {
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
static_cast<int64_t>(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() +
|
||||
m_batch_dims) <= data_pshape.rank().get_length(),
|
||||
"Length of a tuple with indices must not exceed a rank of data tensor "
|
||||
"excluding "
|
||||
"batch dimensions.");
|
||||
}
|
||||
}
|
||||
|
||||
// set output shape
|
||||
set_output_size(1);
|
||||
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() &&
|
||||
indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) {
|
||||
auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length();
|
||||
int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims;
|
||||
int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1;
|
||||
auto output_rank = output_indices_length + slice_length;
|
||||
size_t delta_output_rank = 0;
|
||||
if (m_batch_dims > 0) {
|
||||
delta_output_rank = 1;
|
||||
}
|
||||
std::vector<Dimension> output_shape(output_rank + delta_output_rank);
|
||||
if (m_batch_dims > 0) {
|
||||
if (output_pshape.rank().is_static()) {
|
||||
const auto& out_size = output_pshape.size();
|
||||
std::vector<Dimension> output_shape(out_size - m_batch_dims + 1);
|
||||
output_shape[0] = 1;
|
||||
for (size_t dim = 0; dim < m_batch_dims; dim++) {
|
||||
if (data_pshape[dim].is_static()) {
|
||||
output_shape[0] *= data_pshape[dim].get_length();
|
||||
} else if (indices_pshape[dim].is_static()) {
|
||||
output_shape[0] *= indices_pshape[dim].get_length();
|
||||
if (output_pshape[dim].is_static()) {
|
||||
output_shape[0] *= output_pshape[dim].get_length();
|
||||
} else {
|
||||
output_shape[0] = Dimension::dynamic();
|
||||
break;
|
||||
}
|
||||
}
|
||||
size_t ind = 1;
|
||||
for (size_t dim = m_batch_dims; dim < out_size; dim++) {
|
||||
if (output_pshape[dim].is_static()) {
|
||||
output_shape[ind] = output_pshape[dim].get_length();
|
||||
} else {
|
||||
output_shape[ind] = Dimension::dynamic();
|
||||
}
|
||||
ind++;
|
||||
}
|
||||
|
||||
set_output_type(0, data_type, ov::PartialShape(output_shape));
|
||||
}
|
||||
for (int64_t dim = 0; dim < output_indices_length; dim++) {
|
||||
output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims];
|
||||
}
|
||||
for (int64_t dim = 0; dim < slice_length; dim++) {
|
||||
output_shape[output_indices_length + dim + delta_output_rank] =
|
||||
data_pshape[m_batch_dims + indices_tuple_length + dim];
|
||||
}
|
||||
set_output_type(0, data_type, ov::PartialShape(output_shape));
|
||||
} else {
|
||||
set_output_type(0, data_type, ov::PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,3 +66,28 @@ shared_ptr<Node> op::v5::GatherND::clone_with_new_inputs(const OutputVector& new
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::v5::GatherND>(new_args.at(0), new_args.at(1), m_batch_dims);
|
||||
}
|
||||
|
||||
// ------------------------------ V8 ------------------------------
|
||||
BWDCMP_RTTI_DEFINITION(op::v8::GatherND);
|
||||
|
||||
op::v8::GatherND::GatherND(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims)
|
||||
: GatherNDBase(data, indices, batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v8::GatherND::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v8_GatherND_validate_and_infer_types);
|
||||
validate_inputs_and_infer_shape();
|
||||
}
|
||||
|
||||
bool op::v8::GatherND::visit_attributes(AttributeVisitor& visitor) {
|
||||
NGRAPH_OP_SCOPE(v8_GatherND_visit_attributes);
|
||||
visitor.on_attribute("batch_dims", m_batch_dims);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v8::GatherND::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
NGRAPH_OP_SCOPE(v8_GatherND_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::v8::GatherND>(new_args.at(0), new_args.at(1), m_batch_dims);
|
||||
}
|
||||
|
||||
109
ngraph/core/src/op/util/gather_nd_base.cpp
Normal file
109
ngraph/core/src/op/util/gather_nd_base.cpp
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/util/gather_nd_base.hpp"
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/squeeze.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
BWDCMP_RTTI_DEFINITION(ov::op::util::GatherNDBase);
|
||||
|
||||
ov::op::util::GatherNDBase::GatherNDBase(const Output<Node>& data, const Output<Node>& indices, const size_t batch_dims)
|
||||
: Op({data, indices}),
|
||||
m_batch_dims(batch_dims) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void ov::op::util::GatherNDBase::validate_inputs_and_infer_shape() {
|
||||
// check types of input tensors
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
const auto& indices_type = get_input_element_type(1);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_type.is_integral_number(),
|
||||
"The indices type is expected to be an integer type. Got: ",
|
||||
indices_type);
|
||||
|
||||
// check ranks of input tensors
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& indices_pshape = get_input_partial_shape(1);
|
||||
|
||||
if (data_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_pshape.rank().get_length() > static_cast<int64_t>(m_batch_dims),
|
||||
"Number of batch dimensions must not exceed a rank of data.");
|
||||
}
|
||||
|
||||
if (indices_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_pshape.rank().get_length() > static_cast<int64_t>(m_batch_dims),
|
||||
"Number of batch dimensions must not exceed a rank of indices.");
|
||||
}
|
||||
|
||||
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static()) {
|
||||
// check that batch dimensions of data and indices are the same
|
||||
for (size_t batch_dim = 0; batch_dim < m_batch_dims; batch_dim++) {
|
||||
if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_pshape[batch_dim].get_length() == indices_pshape[batch_dim].get_length(),
|
||||
"Batch dimensions of data and indices must be the same.");
|
||||
}
|
||||
}
|
||||
|
||||
if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) {
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
static_cast<int64_t>(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() +
|
||||
m_batch_dims) <= data_pshape.rank().get_length(),
|
||||
"Length of a tuple with indices must not exceed a rank of data tensor "
|
||||
"excluding "
|
||||
"batch dimensions.");
|
||||
}
|
||||
}
|
||||
|
||||
// set output shape
|
||||
set_output_size(1);
|
||||
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() &&
|
||||
indices_pshape[indices_pshape.rank().get_length() - 1].is_static()) {
|
||||
auto indices_tuple_length = indices_pshape[indices_pshape.rank().get_length() - 1].get_length();
|
||||
int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims;
|
||||
int64_t output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1;
|
||||
auto output_rank = output_indices_length + slice_length;
|
||||
size_t delta_output_rank = 0;
|
||||
delta_output_rank = m_batch_dims;
|
||||
std::vector<Dimension> output_shape(output_rank + delta_output_rank);
|
||||
for (size_t dim = 0; dim < m_batch_dims; dim++) {
|
||||
output_shape[dim] = 1;
|
||||
if (data_pshape[dim].is_static()) {
|
||||
output_shape[dim] = data_pshape[dim].get_length();
|
||||
} else if (indices_pshape[dim].is_static()) {
|
||||
output_shape[dim] = indices_pshape[dim].get_length();
|
||||
} else {
|
||||
output_shape[dim] = Dimension::dynamic();
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int64_t dim = 0; dim < output_indices_length; dim++) {
|
||||
output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims];
|
||||
}
|
||||
for (int64_t dim = 0; dim < slice_length; dim++) {
|
||||
output_shape[output_indices_length + dim + delta_output_rank] =
|
||||
data_pshape[m_batch_dims + indices_tuple_length + dim];
|
||||
}
|
||||
set_output_type(0, data_type, ov::PartialShape(output_shape));
|
||||
} else {
|
||||
set_output_type(0, data_type, ov::PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
@@ -316,3 +316,135 @@ TEST(type_prop, gather_nd_fail_indices_element_type) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------ V8 ------------------------------
|
||||
|
||||
TEST(type_prop, gather_nd_8_slices_from_4d_batch_dims0) {
|
||||
Shape params_shape{2, 3, 11, 12};
|
||||
Shape indices_shape{2, 3, 2};
|
||||
Shape out_shape{2, 3, 11, 12};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 0);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G5->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_scalars_from_4d_batch_dims2) {
|
||||
Shape params_shape{2, 3, 11, 12};
|
||||
Shape indices_shape{2, 3, 2};
|
||||
Shape out_shape{2, 3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G5->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_slices_from_5d_batch_dims2) {
|
||||
Shape params_shape{7, 5, 11, 12, 32};
|
||||
Shape indices_shape{7, 5, 3, 1};
|
||||
Shape out_shape{7, 5, 3, 12, 32};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G5->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim) {
|
||||
PartialShape params_shape{7, Dimension::dynamic(), 11, 12, 32};
|
||||
Shape indices_shape{7, 5, 3, 1};
|
||||
Shape out_shape{7, 5, 3, 12, 32};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G5->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim2) {
|
||||
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32};
|
||||
Shape indices_shape{7, 5, 3, 1};
|
||||
Shape out_shape{7, 5, 3, 12, 32};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G5->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_batch_dim2_with_dyn_dim3) {
|
||||
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
|
||||
Shape indices_shape{7, 5, 3, 1};
|
||||
PartialShape out_shape{7, 5, 3, 12, Dimension::dynamic()};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_batch_dim0_with_dyn_ind_dim) {
|
||||
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
|
||||
PartialShape indices_shape{7, 5, 3, Dimension::dynamic()};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 0);
|
||||
ASSERT_EQ(G5->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_fail_batch_dims_greater_indices_rank) {
|
||||
Shape params_shape{2, 3, 4, 5};
|
||||
Shape indices_shape{2, 1};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
|
||||
try {
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 3);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Number of batch dimensions must not exceed a rank of indices."));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_fail_unequal_batch_dims) {
|
||||
Shape params_shape{2, 3, 4, 5};
|
||||
Shape indices_shape{2, 1, 4};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
|
||||
try {
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch dimensions of data and indices must be the same."));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_nd_8_fail_indices_tuple_greater_data_rank_batch_dims2) {
|
||||
Shape params_shape{2, 1, 4, 5};
|
||||
Shape indices_shape{2, 1, 5, 3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, params_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
|
||||
try {
|
||||
auto G5 = make_shared<op::v8::GatherND>(P, I, 2);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Length of a tuple with indices must not exceed a rank of "
|
||||
"data tensor excluding batch dimensions."));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "ngraph/opsets/opset8.hpp"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
@@ -13,7 +14,7 @@ using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, gather_nd_op) {
|
||||
TEST(attributes, gather_nd_v5_op) {
|
||||
NodeBuilder::get_ops().register_factory<opset5::GatherND>();
|
||||
int batch_dims = 1;
|
||||
auto P = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
|
||||
@@ -25,3 +26,16 @@ TEST(attributes, gather_nd_op) {
|
||||
|
||||
EXPECT_EQ(g_G->get_batch_dims(), G->get_batch_dims());
|
||||
}
|
||||
|
||||
TEST(attributes, gather_nd_v8_op) {
|
||||
NodeBuilder::get_ops().register_factory<opset8::GatherND>();
|
||||
int batch_dims = 1;
|
||||
auto P = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
|
||||
auto I = make_shared<op::Parameter>(element::i32, Shape{2, 1});
|
||||
auto G = make_shared<op::v8::GatherND>(P, I, batch_dims);
|
||||
|
||||
NodeBuilder builder(G);
|
||||
auto g_G = ov::as_type_ptr<opset8::GatherND>(builder.create());
|
||||
|
||||
EXPECT_EQ(g_G->get_batch_dims(), G->get_batch_dims());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user