Incorrect output shape of Gather operation for some models (#8899)

* Fix in the MO infer function of Gather.

* Added comment about done fix.

* Added more tests.

* Now Gather and AttributedGather are always marked as reinterp_shape=True.
This commit is contained in:
Vladimir Gavrilov 2021-12-01 10:22:30 +03:00 committed by GitHub
parent 9ec7bf286e
commit ab22d7d041
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 1 deletions

View File

@ -20,6 +20,7 @@ class Gather(Op):
'type': self.op,
'version': 'opset8',
'batch_dims': 0,
'reinterp_shape': True,
'infer': self.infer,
'force_precision_in_ports': {1: 'int32', 2: 'int64'},
'in_ports_count': 3,
@ -109,7 +110,7 @@ class AttributedGather(Op):
'type': 'Gather',
'axis': 0,
'reinterp_shape': True,
'infer': self.infer,
'force_precision_in_ports': {1: 'int32'},

View File

@ -72,12 +72,48 @@ class TestGatherPartialInfer(unittest.TestCase):
indices_shape=[1, 2],
ref_shape=[3, 1, 2])
def test_shape_axis_1_1(self):
self.build_and_test_shape_inference(axis=1, batch_dims=0,
data_shape=[3, 3],
indices_shape=[1, 2, 4],
ref_shape=[3, 1, 2, 4])
def test_shape_axis_1_2(self):
self.build_and_test_shape_inference(axis=1, batch_dims=0,
data_shape=[1, 2, 4],
indices_shape=[3, 3],
ref_shape=[1, 3, 3, 4])
def test_shape_axis_1_3(self):
self.build_and_test_shape_inference(axis=1, batch_dims=0,
data_shape=[1, 2, 4],
indices_shape=[5, 8, 16],
ref_shape=[1, 5, 8, 16, 4])
def test_shape_axis_0(self):
self.build_and_test_shape_inference(axis=0, batch_dims=0,
data_shape=[3, 3],
indices_shape=[1, 2],
ref_shape=[1, 2, 3])
def test_shape_axis_0_1(self):
self.build_and_test_shape_inference(axis=0, batch_dims=0,
data_shape=[3, 3],
indices_shape=[1, 2, 5],
ref_shape=[1, 2, 5, 3])
def test_shape_axis_0_2(self):
self.build_and_test_shape_inference(axis=0, batch_dims=0,
data_shape=[1, 2, 5],
indices_shape=[3, 3],
ref_shape=[3, 3, 2, 5])
def test_shape_axis_0_3(self):
self.build_and_test_shape_inference(axis=0, batch_dims=0,
data_shape=[1, 2, 5],
indices_shape=[6, 8, 15],
ref_shape=[6, 8, 15, 2, 5])
def test_shape_axis_minus_2(self):
self.build_and_test_shape_inference(axis=-2, batch_dims=0,
data_shape=[2, 3, 7],