Incorrect output shape of Gather operation for some models (#8899)
* Fix in the MO infer function of Gather. * Added comment about done fix. * Added more tests. * Now Gather and AttributedGather are always marked as reinterp_shape=True.
This commit is contained in:
parent
9ec7bf286e
commit
ab22d7d041
@ -20,6 +20,7 @@ class Gather(Op):
|
||||
'type': self.op,
|
||||
'version': 'opset8',
|
||||
'batch_dims': 0,
|
||||
'reinterp_shape': True,
|
||||
'infer': self.infer,
|
||||
'force_precision_in_ports': {1: 'int32', 2: 'int64'},
|
||||
'in_ports_count': 3,
|
||||
@ -109,7 +110,7 @@ class AttributedGather(Op):
|
||||
'type': 'Gather',
|
||||
|
||||
'axis': 0,
|
||||
|
||||
'reinterp_shape': True,
|
||||
'infer': self.infer,
|
||||
|
||||
'force_precision_in_ports': {1: 'int32'},
|
||||
|
@ -72,12 +72,48 @@ class TestGatherPartialInfer(unittest.TestCase):
|
||||
indices_shape=[1, 2],
|
||||
ref_shape=[3, 1, 2])
|
||||
|
||||
def test_shape_axis_1_1(self):
|
||||
self.build_and_test_shape_inference(axis=1, batch_dims=0,
|
||||
data_shape=[3, 3],
|
||||
indices_shape=[1, 2, 4],
|
||||
ref_shape=[3, 1, 2, 4])
|
||||
|
||||
def test_shape_axis_1_2(self):
|
||||
self.build_and_test_shape_inference(axis=1, batch_dims=0,
|
||||
data_shape=[1, 2, 4],
|
||||
indices_shape=[3, 3],
|
||||
ref_shape=[1, 3, 3, 4])
|
||||
|
||||
def test_shape_axis_1_3(self):
|
||||
self.build_and_test_shape_inference(axis=1, batch_dims=0,
|
||||
data_shape=[1, 2, 4],
|
||||
indices_shape=[5, 8, 16],
|
||||
ref_shape=[1, 5, 8, 16, 4])
|
||||
|
||||
def test_shape_axis_0(self):
|
||||
self.build_and_test_shape_inference(axis=0, batch_dims=0,
|
||||
data_shape=[3, 3],
|
||||
indices_shape=[1, 2],
|
||||
ref_shape=[1, 2, 3])
|
||||
|
||||
def test_shape_axis_0_1(self):
|
||||
self.build_and_test_shape_inference(axis=0, batch_dims=0,
|
||||
data_shape=[3, 3],
|
||||
indices_shape=[1, 2, 5],
|
||||
ref_shape=[1, 2, 5, 3])
|
||||
|
||||
def test_shape_axis_0_2(self):
|
||||
self.build_and_test_shape_inference(axis=0, batch_dims=0,
|
||||
data_shape=[1, 2, 5],
|
||||
indices_shape=[3, 3],
|
||||
ref_shape=[3, 3, 2, 5])
|
||||
|
||||
def test_shape_axis_0_3(self):
|
||||
self.build_and_test_shape_inference(axis=0, batch_dims=0,
|
||||
data_shape=[1, 2, 5],
|
||||
indices_shape=[6, 8, 15],
|
||||
ref_shape=[6, 8, 15, 2, 5])
|
||||
|
||||
def test_shape_axis_minus_2(self):
|
||||
self.build_and_test_shape_inference(axis=-2, batch_dims=0,
|
||||
data_shape=[2, 3, 7],
|
||||
|
Loading…
Reference in New Issue
Block a user