[MO|nGraph]GatherND_8 (#7743)

* Add GatherND_8 operation

* Update shape infer function and tests

* Initial commit for nGraph GatherND_8 operation

* Add GatherNDBase class implementation

* Fix base class errors

* Add missrd header

* Update base class

* Update GatherND_8 implementation

* Fix codestyle

* Fix wrong rank

* Implement tests for gatherND_8 shape inference function

* fix codestyle

* Add limitation to doc

* Siplyfy check in shape inference

* Add more test cases

* Update shape inference function

* Add more test cases to cover all case with dynamic input shapes

* Update shape inference function

* Refactor tests

* Add visitor tests for gatherND_8 operation

* Correct comment

* Add additional check is shape inference function

* Update shape inference implementation for gathernd operartion

* Fix codestyle

* Remove restriction for data is fully defined

* Update shape inference functon

* Fix missed check for nonetype

* Remove redundant checks for batch_dims

* Fix codestyle
This commit is contained in:
Anton Chetverikov
2021-11-10 11:54:52 +03:00
committed by GitHub
parent 76994c6ec9
commit c8e1c8e3eb
12 changed files with 644 additions and 139 deletions

View File

@@ -14,7 +14,7 @@ nodes_attributes = {'data': {'kind': 'op'},
'data_data': {'shape': None, 'value': None, 'kind': 'data'},
'indices': {'kind': 'op'},
'indices_data': {'shape': None, 'value': None, 'kind': 'data'},
'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0},
'gathernd_node': {'op': 'GatherNDUpdate', 'kind': 'op', 'batch_dims': 0, 'version': 'opset8'},
'output': {'shape': None, 'value': None, 'kind': 'data'}}
# graph 1
@@ -25,17 +25,21 @@ edges = [('data', 'data_data', {'in': 0}),
('gathernd_node', 'output', {'out': 0})]
# test data for partial infer: gather elements
inputs1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
inputs = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
'indices_data': {'shape': int64_array([3, 2]), 'value': None}}
# test data for partial infer: gather slices
inputs2 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None},
inputs1 = {'data_data': {'shape': int64_array([10, 40, 30]), 'value': None},
'indices_data': {'shape': int64_array([3, 2]), 'value': None}}
# test data for partial infer: gather slices and batch_dims=2
inputs3 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None},
inputs2 = {'data_data': {'shape': int64_array([10, 40, 4, 9]), 'value': None},
'indices_data': {'shape': int64_array([10, 40, 3, 5, 1]), 'value': None}}
# test data for partial infer: gather slices and batch_dims=3 and indices.shape[-1]=len(data.shape)-batch_dims
inputs3 = {'data_data': {'shape': int64_array([1, 64, 64, 320]), 'value': None},
'indices_data': {'shape': int64_array([1, 64, 64, 1, 1]), 'value': None}}
# test data for constant folding: gather elements, batch_dims = 0
inputs4 = {'data_data': {'shape': int64_array([2, 2]), 'value': int64_array([[1, 2],
[3, 4]])},
@@ -110,6 +114,14 @@ output8 = int64_array([[3, 8, 6],
inputs9 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None},
'indices_data': {'shape': shape_array([dynamic_dimension_value, 40, 3, 5, 1]), 'value': None}}
# test data for partial infer: gather slices and batch_dims=2
inputs10 = {'data_data': {'shape': shape_array([40, dynamic_dimension_value, 4, 9]), 'value': None},
'indices_data': {'shape': shape_array([40, dynamic_dimension_value, 3, 5, 1]), 'value': None}}
# test data for partial infer: gather slices and batch_dims=2
inputs11 = {'data_data': {'shape': shape_array([dynamic_dimension_value, 40, 4, 9]), 'value': None},
'indices_data': {'shape': shape_array([40, dynamic_dimension_value, 3, 5, 1]), 'value': None}}
# invalid test case with incorrect rank for indices
inputs_inv1 = {'data_data': {'shape': int64_array([10, 40]), 'value': None},
'indices_data': {'shape': int64_array([5, 3, 4]), 'value': None}}
@@ -123,12 +135,13 @@ inputs_inv3 = {'data_data': {'shape': int64_array([10, 40, 20, 10, 2]), 'value':
'indices_data': {'shape': int64_array([10, 40, 4]), 'value': None}}
class TestGatherNDUpdate(unittest.TestCase):
class TestGatherND_5(unittest.TestCase):
def setUp(self):
nodes_attributes['gathernd_node']['batch_dims'] = 0
nodes_attributes['gathernd_node']['version'] = 'opset5'
def test_partial_infer_gather_element(self):
graph = build_graph(nodes_attributes, edges, inputs1)
graph = build_graph(nodes_attributes, edges, inputs)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
@@ -142,7 +155,7 @@ class TestGatherNDUpdate(unittest.TestCase):
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice(self):
graph = build_graph(nodes_attributes, edges, inputs2)
graph = build_graph(nodes_attributes, edges, inputs1)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
@@ -157,7 +170,7 @@ class TestGatherNDUpdate(unittest.TestCase):
def test_partial_infer_gather_slice_batch_dims2(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
graph = build_graph(nodes_attributes, edges, inputs3)
graph = build_graph(nodes_attributes, edges, inputs2)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
@@ -170,7 +183,22 @@ class TestGatherNDUpdate(unittest.TestCase):
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic(self):
def test_partial_infer_gather_slice_batch_dims3(self):
nodes_attributes['gathernd_node']['batch_dims'] = 3
graph = build_graph(nodes_attributes, edges, inputs3)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = int64_array([4096, 1])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic1(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
graph = build_graph(nodes_attributes, edges, inputs9)
gathernd_node = Node(graph, 'gathernd_node')
@@ -185,6 +213,36 @@ class TestGatherNDUpdate(unittest.TestCase):
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic2(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
graph = build_graph(nodes_attributes, edges, inputs10)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic3(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
graph = build_graph(nodes_attributes, edges, inputs11)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = shape_array([dynamic_dimension_value, 3, 5, 9])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_infer4(self):
graph = build_graph(nodes_attributes, edges, inputs4)
gathernd_node = Node(graph, 'gathernd_node')
@@ -205,7 +263,7 @@ class TestGatherNDUpdate(unittest.TestCase):
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output5, res_output_value),
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
'values do not match expected: {} and given: {}'.format(output5, res_output_value))
def test_infer6(self):
nodes_attributes['gathernd_node']['batch_dims'] = 1
@@ -217,7 +275,7 @@ class TestGatherNDUpdate(unittest.TestCase):
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output6, res_output_value),
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
'values do not match expected: {} and given: {}'.format(output6, res_output_value))
def test_infer7(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
@@ -228,8 +286,9 @@ class TestGatherNDUpdate(unittest.TestCase):
# get the result
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output7, res_output_value),
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
output = output7.reshape([6, 1])
self.assertTrue(np.array_equal(output, res_output_value),
'values do not match expected: {} and given: {}'.format(output, res_output_value))
def test_infer8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
@@ -241,7 +300,32 @@ class TestGatherNDUpdate(unittest.TestCase):
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output8, res_output_value),
'values do not match expected: {} and given: {}'.format(output4, res_output_value))
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
def test_infer9(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
graph = build_graph(nodes_attributes, edges, inputs8)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# get the result
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output8, res_output_value),
'values do not match expected: {} and given: {}'.format(output8, res_output_value))
def test_infer9_opset_5(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
graph = build_graph(nodes_attributes, edges, inputs8)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# get the result
res_output_value = graph.node['output']['value']
output = output8.reshape([6, 3])
self.assertTrue(np.array_equal(output, res_output_value),
'values do not match expected: {} and given: {}'.format(output, res_output_value))
def test_infer_invalid1(self):
graph = build_graph(nodes_attributes, edges, inputs_inv1)
@@ -259,3 +343,114 @@ class TestGatherNDUpdate(unittest.TestCase):
graph = build_graph(nodes_attributes, edges, inputs_inv3)
gathernd_node = Node(graph, 'gathernd_node')
self.assertRaises(AssertionError, GatherND.infer, gathernd_node)
def test_partial_infer_gather_slice_batch_dims2_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs2)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = int64_array([10, 40, 3, 5, 9])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims3_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 3
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs3)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = int64_array([1, 64, 64, 1])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic1_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs9)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = shape_array([dynamic_dimension_value, 40, 3, 5, 9])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic2_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs10)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = shape_array([40, dynamic_dimension_value, 3, 5, 9])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_partial_infer_gather_slice_batch_dims2_dynamic3_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs11)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# prepare reference results
ref_output_shape = shape_array([40, 40, 3, 5, 9])
# get the result
res_output_shape = graph.node['output']['shape']
self.assertTrue(strict_compare_tensors(ref_output_shape, res_output_shape),
'values do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
def test_infer7_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs7)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# get the result
res_output_value = graph.node['output']['value']
output = output7.reshape([2, 3, 1])
self.assertTrue(np.array_equal(output, res_output_value),
'values do not match expected: {} and given: {}'.format(output, res_output_value))
def test_infer8_opset8(self):
nodes_attributes['gathernd_node']['batch_dims'] = 2
nodes_attributes['gathernd_node']['version'] = 'opset8'
graph = build_graph(nodes_attributes, edges, inputs8)
gathernd_node = Node(graph, 'gathernd_node')
GatherND.infer(gathernd_node)
# get the result
res_output_value = graph.node['output']['value']
output = output8.reshape([2, 3, 3])
self.assertTrue(np.array_equal(output, res_output_value),
'values do not match expected: {} and given: {}'.format(output, res_output_value))