diff --git a/model-optimizer/extensions/front/tf/gather_ext.py b/model-optimizer/extensions/front/tf/gather_ext.py index 8af73b15502..faad2f22388 100644 --- a/model-optimizer/extensions/front/tf/gather_ext.py +++ b/model-optimizer/extensions/front/tf/gather_ext.py @@ -31,5 +31,5 @@ class GatherV2FrontExtractor(FrontExtractorOp): @classmethod def extract(cls, node): - Gather.update_node_stat(node, {}) + Gather.update_node_stat(node, {'batch_dims': node.pb.attr['batch_dims'].i}) return cls.enabled diff --git a/model-optimizer/extensions/ops/gather.py b/model-optimizer/extensions/ops/gather.py index 3396a639e7e..e5ca196fca7 100644 --- a/model-optimizer/extensions/ops/gather.py +++ b/model-optimizer/extensions/ops/gather.py @@ -7,6 +7,7 @@ from mo.front.caffe.extractors.utils import get_canonical_axis_index from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Node, Graph from mo.ops.op import Op, PermuteAttrs +from mo.utils.error import Error class Gather(Op): @@ -17,12 +18,10 @@ class Gather(Op): super().__init__(graph, { 'op': self.op, 'type': self.op, - 'version': 'opset1', - + 'version': 'opset7', + 'batch_dims': 0, 'infer': self.infer, - 'force_precision_in_ports': {1: 'int32', 2: 'int64'}, - 'in_ports_count': 3, 'out_ports_count': 1, }, attrs) @@ -30,6 +29,15 @@ class Gather(Op): assert 'axis' not in self.attrs, \ 'Use AttributedGather operation instead of Gather to create it with `axis` as a parameter' + def backend_attrs(self): + version = self.get_opset() + if version == 'opset7': + return ['batch_dims'] + elif version == 'opset1': + return [] + else: + raise Error('Unsupported operation opset version "{}"'.format(version)) + @staticmethod def infer(node: Node): name = node.soft_get('name', node.id) @@ -44,25 +52,46 @@ class Gather(Op): indices_shape = node.in_port(1).data.get_shape() assert indices_shape is not None axis = node.in_port(2).data.get_value() - assert axis is not None - axis = get_canonical_axis_index(data_shape, 
axis) + assert axis is not None, 'axis input is undefined' + + assert -len(data_shape) <= axis < len(data_shape), \ + 'axis must be within interval [-data_rank, data_rank). Instead got axis = {}, data_rank = {} '.\ + format(axis, len(data_shape)) + + batch_dims = node.batch_dims + assert -len(indices_shape) <= batch_dims <= len(indices_shape), \ + 'batch_dims must be within interval [-indices_rank, indices_rank]. Instead got batch_dims = {}, ' \ + 'indices_rank = {} '.format(batch_dims, len(indices_shape)) + + # normalize to non-negative values + axis = axis + len(data_shape) if axis < 0 else axis + batch_dims = batch_dims + len(indices_shape) if batch_dims < 0 else batch_dims + + assert np.array_equal(data_shape[:batch_dims], indices_shape[:batch_dims]), \ + 'data and indices inputs must have equal first dimensions until batch_dims' + + assert batch_dims <= axis, \ + 'normalized batch_dims must be <= axis. Instead got batch_dims = {}, axis = {}'.format(batch_dims, axis) # we import PermuteInputs locally because it uses Gather inside and we have recursive imports from mo.graph.perm_inputs import PermuteInputs PermuteInputs().set_input_permutation(node.in_node(1), node, 'input:0', 'axis') + batch_dims_range = indices_shape[:batch_dims] + out_shape = np.concatenate((data_shape[:axis], indices_shape[batch_dims:], data_shape[axis + 1:])) + data_value = node.in_port(0).data.get_value() indices_value = node.in_port(1).data.get_value() if data_value is not None and indices_value is not None: - node.out_port(0).data.set_value(np.array(np.take(data_value, int64_array(indices_value), axis), - dtype=data_value.dtype)) - return - - shape = np.concatenate((data_shape[:axis], indices_shape)) - if axis < len(data_shape) - 1: - shape = np.concatenate((shape, data_shape[axis + 1:])) - - node.out_port(0).data.set_shape(int64_array(shape)) + if batch_dims == 0: + node.out_port(0).data.set_value(np.take(data_value, indices_value, axis)) + else: + out_value = np.empty(out_shape) + for 
batch_idx in np.ndindex(tuple(batch_dims_range)): + out_value[batch_idx] = np.take(data_value[batch_idx], indices_value[batch_idx], axis - batch_dims) + node.out_port(0).data.set_value(out_value) + else: + node.out_port(0).data.set_shape(int64_array(out_shape)) class AttributedGather(Op): @@ -80,7 +109,7 @@ class AttributedGather(Op): 'force_precision_in_ports': {1: 'int32'}, - 'in_ports_count': 3, + 'in_ports_count': 2, 'out_ports_count': 1, }, attrs) diff --git a/model-optimizer/unit_tests/extensions/ops/gather_test.py b/model-optimizer/unit_tests/extensions/ops/gather_test.py index d71445936b0..4891e4d8b92 100644 --- a/model-optimizer/unit_tests/extensions/ops/gather_test.py +++ b/model-optimizer/unit_tests/extensions/ops/gather_test.py @@ -3,59 +3,204 @@ import unittest -import numpy as np +import numpy.testing as npt from extensions.ops.gather import Gather from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Node -from unit_tests.utils.graph import build_graph +from mo.middle.passes.infer import partial_infer +from mo.utils.error import Error +from unit_tests.utils.graph import valued_const_with_data, result, regular_op_with_empty_data, connect, \ + shaped_parameter, build_graph class TestGatherPartialInfer(unittest.TestCase): + @staticmethod - def _create_graph(): - nodes_attributes = { - - 'gather_input': {'kind': 'op'}, - 'gather_input_data': {'shape': None, 'value': None, 'kind': 'data'}, - 'gather_input2': {'kind': 'op'}, - 'gather_input2_data': {'shape': None, 'value': None, 'kind': 'data'}, - 'gather_input3': {'kind': 'op'}, - 'gather_input3_data': {'shape': None, 'value': 0, 'kind': 'data'}, - - 'gather_node': {'op': 'Gather', 'kind': 'op'}, - 'gather_output': {'shape': None, 'value': None, 'kind': 'data'} - + def build_and_test_value_inference(data, indices, axis, batch_dims, ref_value, negative_test_string=None): + nodes = { + **valued_const_with_data('data', int64_array(data)), + **valued_const_with_data('indices', 
int64_array(indices)), + **valued_const_with_data('axis', int64_array(axis)), + **regular_op_with_empty_data('gather', {'op': 'Gather', 'batch_dims': batch_dims, 'infer': Gather.infer}), + **result('res'), } - return build_graph(nodes_attributes, - [ - ('gather_input', 'gather_input_data'), - ('gather_input2', 'gather_input2_data'), - ('gather_input3', 'gather_input3_data'), - ('gather_input_data', 'gather_node'), - ('gather_input2_data', 'gather_node'), - ('gather_input3_data', 'gather_node'), + edges = [ + *connect('data', '0:gather'), + *connect('indices', '1:gather'), + *connect('axis', '2:gather'), + *connect('gather', 'res') + ] - ('gather_node', 'gather_output'), - ], - { - 'gather_input_data': {'shape': int64_array([10, 15]), 'value': np.ones((3, 15))}, - 'gather_input2_data': {'shape': int64_array([2]), 'value': np.array([0, 2])}, - }) + graph = build_graph(nodes, edges) + graph.stage = 'middle' + partial_infer(graph) - def test_gather_infer(self): - graph = self._create_graph() + node = Node(graph, 'gather') + res = node.out_port(0).data.get_value() + npt.assert_array_equal(res, ref_value) - gather_node = Node(graph, 'gather_node') - Gather.infer(gather_node) + @staticmethod + def build_and_test_shape_inference(data_shape, indices_shape, axis, batch_dims, ref_shape): + nodes = { + **shaped_parameter('data', int64_array(data_shape)), + **shaped_parameter('indices', int64_array(indices_shape)), + **valued_const_with_data('axis', int64_array(axis)), + **regular_op_with_empty_data('gather', {'op': 'Gather', 'batch_dims': batch_dims, 'infer': Gather.infer}), + **result('res'), + } - exp_shape = int64_array([2, 15]) - res_shape = graph.node['gather_output']['shape'] - res_value = graph.node['gather_output']['value'] + edges = [ + *connect('data', '0:gather'), + *connect('indices', '1:gather'), + *connect('axis', '2:gather'), + *connect('gather', 'res') + ] - self.assertTrue(np.array_equal(exp_shape, res_shape), - 'shapes do not match expected: {} and given: 
{}'.format(exp_shape, res_shape)) + graph = build_graph(nodes, edges) + graph.stage = 'middle' + partial_infer(graph) - self.assertTrue(np.array_equal(res_value, np.ones(exp_shape)), - 'shapes do not match expected: {} and given: {}'.format(exp_shape, res_shape)) + node = Node(graph, 'gather') + res = node.out_port(0).data.get_shape() + npt.assert_array_equal(res, ref_shape) + + def test_shape_axis_1(self): + self.build_and_test_shape_inference(axis=1, batch_dims=0, + data_shape=[3, 3], + indices_shape=[1, 2], + ref_shape=[3, 1, 2]) + + def test_shape_axis_0(self): + self.build_and_test_shape_inference(axis=0, batch_dims=0, + data_shape=[3, 3], + indices_shape=[1, 2], + ref_shape=[1, 2, 3]) + + def test_shape_axis_minus_2(self): + self.build_and_test_shape_inference(axis=-2, batch_dims=0, + data_shape=[2, 3, 7], + indices_shape=[1, 4], + ref_shape=[2, 1, 4, 7]) + + def test_shape_axis_1_batch_dims_1(self): + self.build_and_test_shape_inference(axis=1, batch_dims=1, + data_shape=[3, 4], + indices_shape=[3, 1, 2], + ref_shape=[3, 1, 2]) + + def test_shape_axis_2_batch_dims_1(self): + self.build_and_test_shape_inference(axis=2, batch_dims=1, + data_shape=[3, 4, 7], + indices_shape=[3, 1, 2], + ref_shape=[3, 4, 1, 2]) + + def test_shape_axis_2_batch_dims_minus_1(self): + self.build_and_test_shape_inference(axis=2, batch_dims=-1, + data_shape=[3, 1, 7], + indices_shape=[3, 1, 2], + ref_shape=[3, 1, 2]) + + def test_shape_axis_2_batch_dims_minus_2(self): + self.build_and_test_shape_inference(axis=2, batch_dims=-2, + data_shape=[3, 4, 7], + indices_shape=[3, 1, 2], + ref_shape=[3, 4, 1, 2]) + + def test_axis_0_batch_dims_0(self): + self.build_and_test_value_inference(axis=0, batch_dims=0, + data=[1, 2, 3, 4, 5], + indices=[0, 0, 4], + ref_value=[1, 1, 5]) + + def test_axis_1_batch_dims_1(self): + self.build_and_test_value_inference(axis=1, batch_dims=1, + data=[[1, 2, 3, 4, 5], + [6, 7, 8, 9, 10]], + indices=[[0, 0, 4], + [4, 0, 0]], + + ref_value=[[1, 1, 5], + [10, 6, 
6]]) + + def test_axis_minus_1_batch_dims_1(self): + self.build_and_test_value_inference(axis=-1, batch_dims=1, + data=[[1, 2, 3, 4, 5], + [6, 7, 8, 9, 10]], + indices=[[0, 0, 4], + [4, 0, 0]], + + ref_value=[[1, 1, 5], + [10, 6, 6]]) + + def test_axis_2_batch_dims_1(self): + self.build_and_test_value_inference(axis=2, batch_dims=1, + data=[[[[ 1, 2, 3, 4], # <-- first batch + [ 5, 6, 7, 8], + [ 9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20]]], + [[[21, 22, 23, 24], # <-- second batch + [25, 26, 27, 28], + [29, 30, 31, 32], + [33, 34, 35, 36], + [37, 38, 39, 40]]]], # data_shape = (2, 1, 5, 4) + indices=[[1, 2, 4], + [4, 3, 2]], + ref_value=[[[[ 5, 6, 7, 8], + [ 9, 10, 11, 12], + [17, 18, 19, 20]]], + [[[37, 38, 39, 40], + [33, 34, 35, 36], + [29, 30, 31, 32]]]]) + + def test_axis_2_batch_dims_minus_1(self): + self.build_and_test_value_inference(axis=2, batch_dims=-1, + data=[[[[ 1, 2, 3, 4], # <-- first batch + [ 5, 6, 7, 8], + [ 9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20]]], + [[[21, 22, 23, 24], # <-- second batch + [25, 26, 27, 28], + [29, 30, 31, 32], + [33, 34, 35, 36], + [37, 38, 39, 40]]]], # data_shape = (2, 1, 5, 4) + indices=[[1, 2, 4], + [4, 3, 2]], + ref_value=[[[[ 5, 6, 7, 8], + [ 9, 10, 11, 12], + [17, 18, 19, 20]]], + [[[37, 38, 39, 40], + [33, 34, 35, 36], + [29, 30, 31, 32]]]]) + + # negative tests + def test_shape_indices_data_shape_inconsistency(self): + self.assertRaises(Error, self.build_and_test_shape_inference, + axis=2, batch_dims=2, + data_shape=[3, 4, 7], + indices_shape=[3, 1, 2], + ref_shape=[3, 4, 2]) + + def test_shape_batch_dims_greater_than_axis(self): + self.assertRaises(Error, self.build_and_test_shape_inference, + axis=2, batch_dims=3, + data_shape=[3, 4, 7], + indices_shape=[3, 4, 2], + ref_shape=[3, 4, 2]) + + def test_shape_batch_dims_out_of_bound(self): + self.assertRaises(Error, self.build_and_test_shape_inference, + axis=2, batch_dims=4, + data_shape=[3, 4, 7], + indices_shape=[3, 4, 2], + ref_shape=[3, 
4, 2]) + + def test_shape_axis_out_of_bound(self): + self.assertRaises(Error, self.build_and_test_shape_inference, + axis=3, batch_dims=2, + data_shape=[3, 4, 7], + indices_shape=[3, 4, 2], + ref_shape=[3, 4, 2]) diff --git a/model-optimizer/unit_tests/utils/graph.py b/model-optimizer/unit_tests/utils/graph.py index c09abb0e208..1b349ed939a 100644 --- a/model-optimizer/unit_tests/utils/graph.py +++ b/model-optimizer/unit_tests/utils/graph.py @@ -6,6 +6,7 @@ from copy import deepcopy import networkx as nx +from extensions.ops.parameter import Parameter from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Node, Graph from mo.middle.pattern_match import all_edges_in_nodes @@ -275,6 +276,10 @@ shaped_data = lambda name, shape: {name: {'kind': 'data', 'value': None, 'shape': int64_array(shape) if shape is not None else None}} empty_data = lambda name: valued_data(name, None) +shaped_parameter = lambda name, shape: {**regular_op(name, {'op': 'Parameter', 'shape': shape, + 'infer': Parameter.infer}), + **shaped_data(name + '_d', shape)} + result = lambda name=None: {name if name is not None else 'output': {'kind': 'op', 'type': 'Result', 'op': 'Result', 'infer': lambda x: 0}}