Add Gather-7 to MO (#5264)

* initial solution

* added unit-tests + some corrections

* axis getting improvements

* fixed MO IR reader for old IRs

* a couple of corrections

* applied review comments

* corrected negative batch_dims normalization: the value is normalized only for shape calculation, while the original negative value is kept in the IR

* added additional checks and negative tests
This commit is contained in:
Pavel Esir 2021-05-11 22:29:59 +03:00 committed by GitHub
parent dc22c177d5
commit 0b22d6c51c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 236 additions and 57 deletions

View File

@ -31,5 +31,5 @@ class GatherV2FrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
Gather.update_node_stat(node, {})
Gather.update_node_stat(node, {'batch_dims': node.pb.attr['batch_dims'].i})
return cls.enabled

View File

@ -7,6 +7,7 @@ from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op, PermuteAttrs
from mo.utils.error import Error
class Gather(Op):
@ -17,12 +18,10 @@ class Gather(Op):
super().__init__(graph, {
'op': self.op,
'type': self.op,
'version': 'opset1',
'version': 'opset7',
'batch_dims': 0,
'infer': self.infer,
'force_precision_in_ports': {1: 'int32', 2: 'int64'},
'in_ports_count': 3,
'out_ports_count': 1,
}, attrs)
@ -30,6 +29,15 @@ class Gather(Op):
assert 'axis' not in self.attrs, \
'Use AttributedGather operation instead of Gather to create it with `axis` as a parameter'
def backend_attrs(self):
version = self.get_opset()
if version == 'opset7':
return ['batch_dims']
elif version == 'opset1':
return []
else:
raise Error('Unsupported operation opset version "{}"'.format(version))
@staticmethod
def infer(node: Node):
name = node.soft_get('name', node.id)
@ -44,25 +52,46 @@ class Gather(Op):
indices_shape = node.in_port(1).data.get_shape()
assert indices_shape is not None
axis = node.in_port(2).data.get_value()
assert axis is not None
axis = get_canonical_axis_index(data_shape, axis)
assert axis is not None, 'axis input is undefined'
assert -len(data_shape) <= axis < len(data_shape), \
'axis must be within interval [-data_rank, data_rank). Instead got axis = {}, data_rank = {} '.\
format(axis, len(data_shape))
batch_dims = node.batch_dims
assert -len(indices_shape) <= batch_dims <= len(indices_shape), \
'batch_dims must be within interval [-indices_rank, indices_rank]. Instead got batch_dims = {}, ' \
'indices_rank = {} '.format(batch_dims, len(indices_shape))
# normalize to positive values
axis = axis + len(data_shape) if axis < 0 else axis
batch_dims = batch_dims + len(indices_shape) if batch_dims < 0 else batch_dims
assert np.array_equal(data_shape[:batch_dims], indices_shape[:batch_dims]), \
'data and indices inputs must have equal first dimensions until batch_dims'
# Both values are already normalized to non-negative here. The format arguments must follow
# the order of the placeholders in the message (batch_dims first, then axis).
assert batch_dims <= axis, \
'normalized batch_dims must be <= axis. Instead got batch_dims = {}, axis = {}'.format(batch_dims, axis)
# we import PermuteInputs locally because it uses Gather inside and we have recursive imports
from mo.graph.perm_inputs import PermuteInputs
PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis')
batch_dims_range = indices_shape[:batch_dims]
out_shape = np.concatenate((data_shape[:axis], indices_shape[batch_dims:], data_shape[axis + 1:]))
data_value = node.in_port(0).data.get_value()
indices_value = node.in_port(1).data.get_value()
if data_value is not None and indices_value is not None:
node.out_port(0).data.set_value(np.array(np.take(data_value, int64_array(indices_value), axis),
dtype=data_value.dtype))
return
shape = np.concatenate((data_shape[:axis], indices_shape))
if axis < len(data_shape) - 1:
shape = np.concatenate((shape, data_shape[axis + 1:]))
node.out_port(0).data.set_shape(int64_array(shape))
if batch_dims == 0:
node.out_port(0).data.set_value(np.take(data_value, indices_value, axis))
else:
# Allocate the folded result with the dtype of the data input: np.empty defaults to float64,
# which would silently change the type of integer data on the batched (batch_dims != 0) path.
out_value = np.empty(out_shape, dtype=data_value.dtype)
for batch_idx in np.ndindex(tuple(batch_dims_range)):
out_value[batch_idx] = np.take(data_value[batch_idx], indices_value[batch_idx], axis - batch_dims)
node.out_port(0).data.set_value(out_value)
else:
node.out_port(0).data.set_shape(int64_array(out_shape))
class AttributedGather(Op):
@ -80,7 +109,7 @@ class AttributedGather(Op):
'force_precision_in_ports': {1: 'int32'},
'in_ports_count': 3,
'in_ports_count': 2,
'out_ports_count': 1,
}, attrs)

View File

@ -3,59 +3,204 @@
import unittest
import numpy as np
import numpy.testing as npt
from extensions.ops.gather import Gather
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from unit_tests.utils.graph import build_graph
from mo.middle.passes.infer import partial_infer
from mo.utils.error import Error
from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \
shaped_parameter, build_graph
class TestGatherPartialInfer(unittest.TestCase):
@staticmethod
def _create_graph():
nodes_attributes = {
'gather_input': {'kind': 'op'},
'gather_input_data': {'shape': None, 'value': None, 'kind': 'data'},
'gather_input2': {'kind': 'op'},
'gather_input2_data': {'shape': None, 'value': None, 'kind': 'data'},
'gather_input3': {'kind': 'op'},
'gather_input3_data': {'shape': None, 'value': 0, 'kind': 'data'},
'gather_node': {'op': 'Gather', 'kind': 'op'},
'gather_output': {'shape': None, 'value': None, 'kind': 'data'}
def build_and_test_value_inference(data, indices, axis, batch_dims, ref_value, negative_test_string=None):
nodes = {
**valued_const_with_data('data', int64_array(data)),
**valued_const_with_data('indices', int64_array(indices)),
**valued_const_with_data('axis', int64_array(axis)),
**regular_op_with_empty_data('gather', {'op': 'Gather', 'batch_dims': batch_dims, 'infer': Gather.infer}),
**result('res'),
}
return build_graph(nodes_attributes,
[
('gather_input', 'gather_input_data'),
('gather_input2', 'gather_input2_data'),
('gather_input3', 'gather_input3_data'),
('gather_input_data', 'gather_node'),
('gather_input2_data', 'gather_node'),
('gather_input3_data', 'gather_node'),
edges = [
*connect('data', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', 'res')
]
('gather_node', 'gather_output'),
],
{
'gather_input_data': {'shape': int64_array([10, 15]), 'value': np.ones((3, 15))},
'gather_input2_data': {'shape': int64_array([2]), 'value': np.array([0, 2])},
})
graph = build_graph(nodes, edges)
graph.stage = 'middle'
partial_infer(graph)
def test_gather_infer(self):
graph = self._create_graph()
node = Node(graph, 'gather')
res = node.out_port(0).data.get_value()
npt.assert_array_equal(res, ref_value)
gather_node = Node(graph, 'gather_node')
Gather.infer(gather_node)
@staticmethod
def build_and_test_shape_inference(data_shape, indices_shape, axis, batch_dims, ref_shape):
nodes = {
**shaped_parameter('data', int64_array(data_shape)),
**shaped_parameter('indices', int64_array(indices_shape)),
**valued_const_with_data('axis', int64_array(axis)),
**regular_op_with_empty_data('gather', {'op': 'Gather', 'batch_dims': batch_dims, 'infer': Gather.infer}),
**result('res'),
}
exp_shape = int64_array([2, 15])
res_shape = graph.node['gather_output']['shape']
res_value = graph.node['gather_output']['value']
edges = [
*connect('data', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', 'res')
]
self.assertTrue(np.array_equal(exp_shape, res_shape),
'shapes do not match expected: {} and given: {}'.format(exp_shape, res_shape))
graph = build_graph(nodes, edges)
graph.stage = 'middle'
partial_infer(graph)
self.assertTrue(np.array_equal(res_value, np.ones(exp_shape)),
'shapes do not match expected: {} and given: {}'.format(exp_shape, res_shape))
node = Node(graph, 'gather')
res = node.out_port(0).data.get_shape()
npt.assert_array_equal(res, ref_shape)
def test_shape_axis_1(self):
self.build_and_test_shape_inference(axis=1, batch_dims=0,
data_shape=[3, 3],
indices_shape=[1, 2],
ref_shape=[3, 1, 2])
def test_shape_axis_0(self):
self.build_and_test_shape_inference(axis=0, batch_dims=0,
data_shape=[3, 3],
indices_shape=[1, 2],
ref_shape=[1, 2, 3])
def test_shape_axis_minus_2(self):
self.build_and_test_shape_inference(axis=-2, batch_dims=0,
data_shape=[2, 3, 7],
indices_shape=[1, 4],
ref_shape=[2, 1, 4, 7])
def test_shape_axis_1_batch_dims_1(self):
self.build_and_test_shape_inference(axis=1, batch_dims=1,
data_shape=[3, 4],
indices_shape=[3, 1, 2],
ref_shape=[3, 1, 2])
def test_shape_axis_2_batch_dims_1(self):
self.build_and_test_shape_inference(axis=2, batch_dims=1,
data_shape=[3, 4, 7],
indices_shape=[3, 1, 2],
ref_shape=[3, 4, 1, 2])
def test_shape_axis_2_batch_dims_minus_1(self):
self.build_and_test_shape_inference(axis=2, batch_dims=-1,
data_shape=[3, 1, 7],
indices_shape=[3, 1, 2],
ref_shape=[3, 1, 2])
def test_shape_axis_2_batch_dims_minus_2(self):
self.build_and_test_shape_inference(axis=2, batch_dims=-2,
data_shape=[3, 4, 7],
indices_shape=[3, 1, 2],
ref_shape=[3, 4, 1, 2])
def test_axis_0_batch_dims_0(self):
self.build_and_test_value_inference(axis=0, batch_dims=0,
data=[1, 2, 3, 4, 5],
indices=[0, 0, 4],
ref_value=[1, 1, 5])
def test_axis_1_batch_dims_1(self):
self.build_and_test_value_inference(axis=1, batch_dims=1,
data=[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]],
indices=[[0, 0, 4],
[4, 0, 0]],
ref_value=[[1, 1, 5],
[10, 6, 6]])
def test_axis_minus_1_batch_dims_1(self):
self.build_and_test_value_inference(axis=-1, batch_dims=1,
data=[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]],
indices=[[0, 0, 4],
[4, 0, 0]],
ref_value=[[1, 1, 5],
[10, 6, 6]])
def test_axis_2_batch_dims_1(self):
self.build_and_test_value_inference(axis=2, batch_dims=1,
data=[[[[ 1, 2, 3, 4], # <-- first batch
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20]]],
[[[21, 22, 23, 24], # < -- second batch
[25, 26, 27, 28],
[29, 30, 31, 32],
[33, 34, 35, 36],
[37, 38, 39, 40]]]], # data_shape = (2, 1, 5, 4)
indices=[[1, 2, 4],
[4, 3, 2]],
ref_value=[[[[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[17, 18, 19, 20]]],
[[[37, 38, 39, 40],
[33, 34, 35, 36],
[29, 30, 31, 32]]]])
def test_axis_2_batch_dims_mimus_1(self):
self.build_and_test_value_inference(axis=2, batch_dims=-1,
data=[[[[ 1, 2, 3, 4], # <-- first batch
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20]]],
[[[21, 22, 23, 24], # < -- second batch
[25, 26, 27, 28],
[29, 30, 31, 32],
[33, 34, 35, 36],
[37, 38, 39, 40]]]], # data_shape = (2, 1, 5, 4)
indices=[[1, 2, 4],
[4, 3, 2]],
ref_value=[[[[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[17, 18, 19, 20]]],
[[[37, 38, 39, 40],
[33, 34, 35, 36],
[29, 30, 31, 32]]]])
# negative tests
def test_shape_indices_data_shape_inconsistency(self):
self.assertRaises(Error, self.build_and_test_shape_inference,
axis=2, batch_dims=2,
data_shape=[3, 4, 7],
indices_shape=[3, 1, 2],
ref_shape=[3, 4, 2])
def test_shape_batch_dims_greater_than_axis(self):
self.assertRaises(Error, self.build_and_test_shape_inference,
axis=2, batch_dims=3,
data_shape=[3, 4, 7],
indices_shape=[3, 4, 2],
ref_shape=[3, 4, 2])
def test_shape_batch_dims_out_of_bound(self):
self.assertRaises(Error, self.build_and_test_shape_inference,
axis=2, batch_dims=4,
data_shape=[3, 4, 7],
indices_shape=[3, 4, 2],
ref_shape=[3, 4, 2])
def test_shape_axis_out_of_bound(self):
self.assertRaises(Error, self.build_and_test_shape_inference,
axis=3, batch_dims=2,
data_shape=[3, 4, 7],
indices_shape=[3, 4, 2],
ref_shape=[3, 4, 2])

View File

@ -6,6 +6,7 @@ from copy import deepcopy
import networkx as nx
from extensions.ops.parameter import Parameter
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.middle.pattern_match import all_edges_in_nodes
@ -275,6 +276,10 @@ shaped_data = lambda name, shape: {name: {'kind': 'data', 'value': None,
'shape': int64_array(shape) if shape is not None else None}}
empty_data = lambda name: valued_data(name, None)
shaped_parameter = lambda name, shape: {**regular_op(name, {'op': 'Parameter', 'shape': shape,
'infer': Parameter.infer}),
**shaped_data(name + '_d', shape)}
result = lambda name=None: {name if name is not None else 'output': {'kind': 'op', 'type': 'Result', 'op': 'Result',
'infer': lambda x: 0}}