Implementation of ArgMin ONNX + TF extractors (#5126)

* implement argmin extractors

* reconsidering argmax to topk

* arg ops refactoring

* rename ArgMaxToTopK

* added unittests

* update docs

* move unittest file to new folder

* resolve review conversations

* revert changes with argmax.py, move argmin op to a new file

* rename ArgMaxSqueeze

* updated BOM file

* little fix

* code refactoring in ArgMaxOp, updated unittests

Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
Yegor Kruglov
2021-05-06 13:41:49 +03:00
committed by GitHub
parent e3ea9bf4dd
commit abb1ca657e
12 changed files with 327 additions and 91 deletions

View File

@@ -128,6 +128,7 @@ Standard TensorFlow\* operations:
| AddV2 | No |
| AddN | No |
| ArgMax | No |
| ArgMin | No |
| Asinh | No |
| Atanh | No |
| AvgPool | No |
@@ -398,6 +399,7 @@ Standard ONNX\* operators:
| Add | No |
| Affine | No |
| ArgMax | No |
| ArgMin | No |
| Asin | No |
| Asinh | No |
| Atan | No |

View File

@@ -63,7 +63,7 @@ extensions/back/TopKNormalizer.py
extensions/back/TransposeReduceFusing.py
extensions/back/UselessConcatRemoval.py
extensions/front/__init__.py
extensions/front/ArgMaxSqueeze.py
extensions/front/ArgOpsSqueeze.py
extensions/front/ATenToEmbeddingBag.py
extensions/front/AttributedClampNormalizer.py
extensions/front/AttributedGatherNormalizer.py
@@ -255,6 +255,7 @@ extensions/front/onnx/__init__.py
extensions/front/onnx/activation_ext.py
extensions/front/onnx/affine_ext.py
extensions/front/onnx/argmax_ext.py
extensions/front/onnx/argmin_ext.py
extensions/front/onnx/aten_ext.py
extensions/front/onnx/AttributedSliceToSlice.py
extensions/front/onnx/cast_ext.py
@@ -368,6 +369,7 @@ extensions/front/Swish_fusion.py
extensions/front/tf/__init__.py
extensions/front/tf/activation_ext.py
extensions/front/tf/argmax_ext.py
extensions/front/tf/argmin_ext.py
extensions/front/tf/assign_elimination.py
extensions/front/tf/automl_efficientdet.json
extensions/front/tf/AutomlEfficientDet.py
@@ -545,7 +547,7 @@ extensions/middle/AddIsCyclicAttribute.py
extensions/middle/AddMeanScaleValues.py
extensions/middle/ApplyNHWCtoNCHWpermutation.py
extensions/middle/ApplyPermutations.py
extensions/middle/ArgMaxToTopK.py
extensions/middle/ArgOpsToTopK.py
extensions/middle/AttributedTileNormalizer.py
extensions/middle/BiasAddBroadcasting.py
extensions/middle/BinarizeWeightsM1P1.py
@@ -641,6 +643,7 @@ extensions/ops/accum.py
extensions/ops/activation_ops.py
extensions/ops/adaptive_avg_pooling.py
extensions/ops/argmax.py
extensions/ops/argmin.py
extensions/ops/assert_op.py
extensions/ops/aten.py
extensions/ops/axpy.py

View File

@@ -7,22 +7,21 @@ from mo.ops.const import Const
from mo.ops.squeeze import Squeeze
class ArgMaxSqueeze(FrontReplacementSubgraph):
class ArgOpsSqueeze(FrontReplacementSubgraph):
"""
In some frameworks ArgMax operation has keepdims attribute that indicates whether to stay a dimension along
which maximum is computed or not. In case of keepdims=0 this dimension should be removed but ArgMax operation in
IR format is not designed to cover this case. So we should additionally add Squeeze operation right after ArgMax
for this case.
In some frameworks the ArgMax/ArgMin operation has a keepdims attribute that indicates whether to keep the
dimension along which the maximum/minimum is computed. In case of keepdims=0 this dimension should be removed,
but the ArgMax/ArgMin operation in IR format is not designed to cover this case, so we additionally add a
Squeeze operation right after ArgMax/ArgMin.
"""
op = "ArgMax"
enabled = True
def pattern(self):
return dict(nodes=[('argmax', dict(op='ArgMax', keepdims=0))],
return dict(nodes=[('node', dict(op=lambda x: x in ['ArgMax', 'ArgMin'], keepdims=0))],
edges=[])
def replace_sub_graph(self, graph: Graph, match: dict):
node = match['argmax']
node = match['node']
connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
squeeze_node = Squeeze(graph, dict()).create_node([], dict(name=node.name + '/Squeeze'))
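
For context, a minimal numpy sketch (not part of the patch) of the keepdims semantics this transformation restores: the IR-level result keeps the reduced axis, so a Squeeze must be appended when the framework op was created with keepdims=0.

import numpy as np

x = np.random.rand(2, 5, 3)

# Framework-level ArgMax/ArgMin with keepdims=0 drops the reduced axis ...
dropped = np.argmax(x, axis=1)                       # shape (2, 3)

# ... while the IR-level replacement (top-1 TopK) keeps it, hence the extra Squeeze.
kept = np.expand_dims(np.argmax(x, axis=1), axis=1)  # shape (2, 1, 3)
squeezed = np.squeeze(kept, axis=1)                  # shape (2, 3), matches `dropped`

assert squeezed.shape == dropped.shape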

View File

@@ -0,0 +1,26 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.ops.argmin import ArgMinOp
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
class ArgMinFrontExtractor(FrontExtractorOp):
op = 'ArgMin'
enabled = True
@classmethod
def extract(cls, node):
keepdims = onnx_attr(node, 'keepdims', 'i', default=1)
axis = onnx_attr(node, 'axis', 'i', default=0)
attrs = {
'axis': axis,
'top_k': 1,
'keepdims': keepdims,
'remove_values_output': True
}
ArgMinOp.update_node_stat(node, attrs)
return cls.enabled
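
For reference, a small sketch (not part of the patch) of the kind of ONNX node whose attributes the extractor above reads through onnx_attr(); when the attributes are absent, the extractor falls back to the spec defaults axis=0 and keepdims=1.

from onnx import helper

# ArgMin node with explicit attributes; omitted attributes take the defaults
# that the extractor mirrors (axis=0, keepdims=1).
argmin_node = helper.make_node(
    'ArgMin',
    inputs=['data'],
    outputs=['indices'],
    axis=1,       # becomes the 'axis' attribute of the MO node
    keepdims=0,   # later triggers the ArgOpsSqueeze front transformation
)
print(argmin_node)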

View File

@@ -16,6 +16,6 @@ class ArgMaxFrontExtractor(FrontExtractorOp):
def extract(cls, node):
ArgMaxOp.update_node_stat(node, {'out_max_val': 0, 'top_k': 1, 'axis': None,
'dim_attrs': ['axis'], 'keepdims': 0, 'remove_values_output': True,
'output_type': tf_dtype_extractor(node.pb.attr['out_type'].type, np.int64),
'output_type': tf_dtype_extractor(node.pb.attr['output_type'].type, np.int64),
})
return cls.enabled
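
A quick sanity check of the fix above (assuming TensorFlow is installed): the ArgMax/ArgMin node attribute is named 'output_type', which is what the extractor now reads instead of 'out_type'.

import tensorflow as tf

x = tf.constant([[1.0, 3.0, 2.0]])
# The raw op exposes the attribute under the name 'output_type' (not 'out_type').
idx = tf.raw_ops.ArgMax(input=x, dimension=1, output_type=tf.int32)
print(idx.numpy())  # [1]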

View File

@@ -0,0 +1,25 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.argmin import ArgMinOp
from mo.front.extractor import FrontExtractorOp
from mo.front.tf.extractors.utils import tf_dtype_extractor
class ArgMinFrontExtractor(FrontExtractorOp):
op = 'ArgMin'
enabled = True
@classmethod
def extract(cls, node):
attrs = {
'top_k': 1,
'axis': None,
'keepdims': 0,
'remove_values_output': True,
'output_type': tf_dtype_extractor(node.pb.attr['output_type'].type, np.int64)
}
ArgMinOp.update_node_stat(node, attrs)
return cls.enabled

View File

@@ -8,24 +8,24 @@ from mo.ops.concat import Concat
from mo.ops.const import Const
class ArgMaxToTopK(MiddleReplacementPattern):
class ArgOpsToTopK(MiddleReplacementPattern):
"""
The transformation replaces ArgMax with the TopK layer.
The transformation replaces ArgMax/ArgMin with the TopK layer.
"""
op = "ArgMax"
enabled = True
force_clean_up = True
def pattern(self):
return dict(
nodes=[
('argmax', dict(op='ArgMax')),
('node', dict(op=lambda x: x in ['ArgMax', 'ArgMin'])),
],
edges=[]
)
def replace_pattern(self, graph: Graph, match: dict):
node = match['argmax']
node = match['node']
node_name = node.soft_get('name', node.id)
connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
@@ -36,7 +36,9 @@ class ArgMaxToTopK(MiddleReplacementPattern):
assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)
topk_node = TopK(graph, {'axis': axis, 'mode': 'max', 'sort': 'index',
topk_mode = 'max' if node.op == 'ArgMax' else 'min'
topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
'remove_values_output': node.has_and_set('remove_values_output'),
'index_element_type': node.output_type}).create_node()
node.in_port(0).get_connection().set_destination(topk_node.in_port(0))
@@ -47,7 +49,7 @@ class ArgMaxToTopK(MiddleReplacementPattern):
topk_node.out_port(0).connect(concat_node.in_port(1)) # indices
topk_node.out_port(1).connect(concat_node.in_port(0)) # values
if not node.out_port(0).disconnected():
node.out_port(0).get_connection().set_source(concat_node.out_port(1))
node.out_port(0).get_connection().set_source(concat_node.out_port(0))
else:
if not node.out_port(0).disconnected():
node.out_port(0).get_connection().set_source(topk_node.out_port(1))
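
The equivalence behind this transformation, as a small numpy sketch (illustration only, not the MO code): ArgMax/ArgMin along an axis equals the index output of a top-1 TopK in 'max'/'min' mode along the same axis.

import numpy as np

x = np.array([[4.0, 9.0, 1.0],
              [7.0, 2.0, 8.0]])

argmax_idx = np.argmax(x, axis=1)            # [1, 2]
top1_max_idx = np.argsort(-x, axis=1)[:, 0]  # indices of a top-1 'max' TopK

argmin_idx = np.argmin(x, axis=1)            # [2, 1]
top1_min_idx = np.argsort(x, axis=1)[:, 0]   # indices of a top-1 'min' TopK

assert np.array_equal(argmax_idx, top1_max_idx)
assert np.array_equal(argmin_idx, top1_min_idx)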

View File

@@ -6,18 +6,57 @@ import logging as log
import numpy as np
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op, PermuteAttrs
def arg_ops_infer(node: Node):
shape = node.in_port(0).data.get_shape()
node_name = node.soft_get('name', node.id)
assert shape is not None, "Input shape for the node {} is None".format(node_name)
# there are two inputs in TensorFlow: the second input is the axis for ArgMax/ArgMin
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
if len(connected_in_ports) == 2:
axis = node.in_port(1).data.get_value()
if axis is None:
log.debug('The second argument to {} is None'.format(node.soft_get('name', node.id)))
return
node.axis = axis
# remove the unnecessary input
node.in_port(1).disconnect()
num_top_axes = shape.size
if num_top_axes < 3:
num_top_axes = 3
out_shape = np.ones(num_top_axes, dtype=np.int64)
if node.has_valid('axis'):
axis = get_canonical_axis_index(shape, node.axis)
node.axis = axis
out_shape = int64_array(shape)
out_shape[axis] = node.top_k
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
else:
out_shape[0] = shape[0]
out_shape[2] = node.top_k
if node.has_and_set('out_max_val'):
out_shape[1] = 2
node.out_port(0).data.set_shape(out_shape)
class ArgMaxOp(Op):
op = 'ArgMax'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': __class__.op,
'op': __class__.op,
'infer': ArgMaxOp.argmax_infer,
'type': None,
'op': self.op,
'infer': arg_ops_infer,
'output_type': np.int64,
'in_ports_count': 2,
'out_ports_count': 1,
@@ -30,38 +69,3 @@ class ArgMaxOp(Op):
'top_k',
'axis',
]
@staticmethod
def argmax_infer(node: Node):
shape = node.in_node(0).shape
if shape is None:
return
# there are two inputs in TensorFlow. The second input is the axis for ArgMax
if len(node.in_nodes()) == 2:
if node.in_node(1).value is None:
log.debug('The second argument to ArgMax is None')
return
node.axis = node.in_node(1).value.item()
# remove the unnecessary input
node.graph.remove_edge(node.in_node(1).id, node.id)
num_top_axes = shape.size
if num_top_axes < 3:
num_top_axes = 3
out_shape = np.ones(num_top_axes, dtype=int)
if node.has_valid('axis'):
axis = get_canonical_axis_index(shape, node.axis)
node.axis = axis
out_shape = np.array(shape)
out_shape[axis] = node.top_k
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
else:
out_shape[0] = shape[0]
out_shape[2] = node.top_k
if node.out_max_val:
out_shape[1] = 2
node.out_node().shape = out_shape
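
For reference, the shape rule implemented by arg_ops_infer above, restated as a standalone sketch (hypothetical helper name, illustration only):

import numpy as np

def arg_ops_out_shape(in_shape, top_k, axis=None, out_max_val=False):
    in_shape = np.array(in_shape, dtype=np.int64)
    if axis is not None:
        axis = axis if axis >= 0 else axis + in_shape.size  # canonicalize a negative axis
        out_shape = in_shape.copy()
        out_shape[axis] = top_k
        return out_shape
    # Caffe-style path without an axis: [N, 1 or 2, top_k, 1, ...], rank padded up to at least 3
    out_shape = np.ones(max(in_shape.size, 3), dtype=np.int64)
    out_shape[0] = in_shape[0]
    out_shape[1] = 2 if out_max_val else 1
    out_shape[2] = top_k
    return out_shape

print(arg_ops_out_shape([1, 3, 1025, 2049], top_k=100, axis=2))   # -> [1, 3, 100, 2049]
print(arg_ops_out_shape([1, 1025], top_k=100, out_max_val=True))  # -> [1, 2, 100]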

View File

@@ -0,0 +1,30 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.argmax import arg_ops_infer
from mo.graph.graph import Graph
from mo.ops.op import Op
class ArgMinOp(Op):
op = 'ArgMin'
enabled = False
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': None,
'op': self.op,
'infer': arg_ops_infer,
'output_type': np.int64,
'in_ports_count': 2,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)
def supported_attrs(self):
return [
'top_k',
'axis',
]

View File

@@ -5,7 +5,7 @@ import unittest
from unittest.mock import patch
from extensions.front.caffe.argmax_ext import ArgMaxFrontExtractor
from extensions.ops.argmax import ArgMaxOp
from extensions.ops.argmax import ArgMaxOp, arg_ops_infer
from mo.ops.op import Op
from unit_tests.utils.extractors import FakeMultiParam
from unit_tests.utils.graph import FakeNode
@@ -44,7 +44,7 @@ class TestArgMaxExt(unittest.TestCase):
'out_max_val': True,
'top_k': 100,
'axis': 2,
'infer': ArgMaxOp.argmax_infer,
'infer': arg_ops_infer,
'remove_values_output': True,
}

View File

@@ -0,0 +1,152 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.middle.ArgOpsToTopK import ArgOpsToTopK
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import regular_op_with_empty_data, result, build_graph, connect, \
valued_const_with_data, regular_op, empty_data, connect_front
nodes_attributes = {
**regular_op_with_empty_data('input', {'op': 'Parameter', 'type': 'Parameter'}),
**regular_op_with_empty_data('argmax', {'op': 'ArgMax', 'type': None, 'out_max_val': 0, 'top_k': 1, 'axis': 0,
'output_type': np.int32, 'remove_values_output': True}),
**regular_op_with_empty_data('argmin', {'op': 'ArgMin', 'type': None, 'top_k': 1, 'axis': 0,
'output_type': np.int32, 'remove_values_output': True}),
**result('result'),
**valued_const_with_data('axis_const', int64_array([1])),
**regular_op('topk', {'op': 'TopK', 'type': 'TopK', 'sort': 'index', 'index_element_type': np.int32}),
**empty_data('topk_out_0_data'),
**empty_data('topk_out_1_data'),
**regular_op_with_empty_data('topk_scalar', {'op': 'Const', 'type': 'Const', 'value': int64_array([1]),
'shape': []}),
**regular_op_with_empty_data('concat', {'op': 'Concat', 'type': 'Concat', 'axis': 1})
}
class ArgOpsToTopKTest(unittest.TestCase):
def test_tf_argmax_to_topk(self):
graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:argmax'),
*connect('axis_const', '1:argmax'),
*connect('argmax', 'result')
],
nodes_with_edges_only=True)
ArgOpsToTopK().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:topk'),
*connect('topk_scalar', '1:topk'),
*connect_front('topk:1', 'topk_out_1_data'),
*connect_front('topk_out_1_data', 'result'),
],
update_attributes={
'topk': {'axis': int64_array([1]), 'mode': 'max', 'remove_values_output': True},
},
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_tf_argmin_to_topk(self):
graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:argmin'),
*connect('axis_const', '1:argmin'),
*connect('argmin', 'result')
],
nodes_with_edges_only=True)
ArgOpsToTopK().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:topk'),
*connect('topk_scalar', '1:topk'),
*connect_front('topk:1', 'topk_out_1_data'),
*connect_front('topk_out_1_data', 'result')
],
update_attributes={
'topk': {'axis': int64_array([1]), 'mode': 'min', 'remove_values_output': True},
},
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_onnx_argmax_to_topk(self):
graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', 'argmax'),
*connect('argmax', 'result')
],
nodes_with_edges_only=True)
ArgOpsToTopK().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:topk'),
*connect('topk_scalar', '1:topk'),
*connect_front('topk:1', 'topk_out_1_data'),
*connect_front('topk_out_1_data', 'result')
],
update_attributes={
'topk': {'axis': 0, 'mode': 'max', 'remove_values_output': True},
},
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_onnx_argmin_to_topk(self):
graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', 'argmin'),
*connect('argmin', 'result')
],
nodes_with_edges_only=True)
ArgOpsToTopK().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:topk'),
*connect('topk_scalar', '1:topk'),
*connect_front('topk:1', 'topk_out_1_data'),
*connect_front('topk_out_1_data', 'result')
],
update_attributes={
'topk': {'axis': 0, 'mode': 'min', 'remove_values_output': True},
},
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_caffe_argmax_to_topk(self):
graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', 'argmax'),
*connect('argmax', 'result')
],
update_attributes={
'argmax': {'out_max_val': 1}
},
nodes_with_edges_only=True)
ArgOpsToTopK().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes_attrs=nodes_attributes,
edges=[
*connect('input', '0:topk'),
*connect('topk_scalar', '1:topk'),
*connect_front('topk:0', 'topk_out_0_data'),
*connect_front('topk:1', 'topk_out_1_data'),
*connect_front('topk_out_0_data', '1:concat'),
*connect_front('topk_out_1_data', '0:concat'),
*connect('concat', 'result')
],
update_attributes={
'topk': {'axis': 0, 'mode': 'max', 'remove_values_output': True},
},
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@@ -5,21 +5,25 @@ import unittest
import numpy as np
from extensions.ops.argmax import ArgMaxOp
from extensions.ops.argmax import arg_ops_infer
from mo.graph.graph import Node
from unit_tests.utils.graph import build_graph
nodes_attributes = {'node_1': {'type': 'Identity', 'kind': 'op'},
nodes_attributes = {
'op_input': {'kind': 'op', 'op': 'Parameter'},
'node_1': {'kind': 'data'},
'argmax': {'op': 'ArgMax', 'kind': 'op'},
'node_3': {'type': 'Identity', 'kind': 'op'},
'op_output': { 'kind': 'op', 'op': 'Result'}
'node_3': {'kind': 'data', 'value': None},
'op_output': {'kind': 'op', 'op': 'Result'}
}
class TestArgMaxOp(unittest.TestCase):
def test_caffe_argmax_axis(self):
graph = build_graph(nodes_attributes,
[('node_1', 'argmax'),
[
('op_input', 'node_1'),
('node_1', 'argmax'),
('argmax', 'node_3'),
('node_3', 'op_output')
],
@@ -33,7 +37,7 @@ class TestArgMaxOp(unittest.TestCase):
})
argmax_node = Node(graph, 'argmax')
ArgMaxOp.argmax_infer(argmax_node)
arg_ops_infer(argmax_node)
exp_shape = np.array([1, 3, 100, 2049])
res_shape = graph.node['node_3']['shape']
for i in range(0, len(exp_shape)):
@@ -41,7 +45,9 @@ class TestArgMaxOp(unittest.TestCase):
def test_caffe_argmax_axis_negative(self):
graph = build_graph(nodes_attributes,
[('node_1', 'argmax'),
[
('op_input', 'node_1'),
('node_1', 'argmax'),
('argmax', 'node_3'),
('node_3', 'op_output')
],
@@ -55,7 +61,7 @@ class TestArgMaxOp(unittest.TestCase):
})
argmax_node = Node(graph, 'argmax')
ArgMaxOp.argmax_infer(argmax_node)
arg_ops_infer(argmax_node)
exp_shape = np.array([1, 3, 1025, 100])
res_shape = graph.node['node_3']['shape']
self.assertEqual(argmax_node.axis, 3)
@@ -64,7 +70,9 @@ class TestArgMaxOp(unittest.TestCase):
def test_caffe_argmax_no_axis(self):
graph = build_graph(nodes_attributes,
[('node_1', 'argmax'),
[
('op_input', 'node_1'),
('node_1', 'argmax'),
('argmax', 'node_3'),
('node_3', 'op_output')
],
@@ -77,7 +85,7 @@ class TestArgMaxOp(unittest.TestCase):
})
argmax_node = Node(graph, 'argmax')
ArgMaxOp.argmax_infer(argmax_node)
arg_ops_infer(argmax_node)
exp_shape = np.array([1, 2, 100, 1])
res_shape = graph.node['node_3']['shape']
for i in range(0, len(exp_shape)):
@@ -85,7 +93,9 @@ class TestArgMaxOp(unittest.TestCase):
def test_caffe_argmax_extend_shape(self):
graph = build_graph(nodes_attributes,
[('node_1', 'argmax'),
[
('op_input', 'node_1'),
('node_1', 'argmax'),
('argmax', 'node_3'),
('node_3', 'op_output')
],
@@ -98,7 +108,7 @@ class TestArgMaxOp(unittest.TestCase):
})
argmax_node = Node(graph, 'argmax')
ArgMaxOp.argmax_infer(argmax_node)
arg_ops_infer(argmax_node)
exp_shape = np.array([1, 2, 100])
res_shape = graph.node['node_3']['shape']
for i in range(0, len(exp_shape)):
@@ -106,7 +116,9 @@ class TestArgMaxOp(unittest.TestCase):
def test_caffe_argmax_out_max_val_false(self):
graph = build_graph(nodes_attributes,
[('node_1', 'argmax'),
[
('op_input', 'node_1'),
('node_1', 'argmax'),
('argmax', 'node_3'),
('node_3', 'op_output')
],
@@ -119,27 +131,8 @@ class TestArgMaxOp(unittest.TestCase):
})
argmax_node = Node(graph, 'argmax')
ArgMaxOp.argmax_infer(argmax_node)
arg_ops_infer(argmax_node)
exp_shape = np.array([1, 1, 100])
res_shape = graph.node['node_3']['shape']
for i in range(0, len(exp_shape)):
self.assertEqual(exp_shape[i], res_shape[i])
def test_caffe_argmax_no_shape(self):
graph = build_graph(nodes_attributes,
[('node_1', 'argmax'),
('argmax', 'node_3'),
('node_3', 'op_output')
],
{'node_3': {'shape': None},
'node_1': {'shape': None},
'argmax': {
'out_max_val': False,
'top_k': 100
}
})
argmax_node = Node(graph, 'argmax')
ArgMaxOp.argmax_infer(argmax_node)
res_shape = graph.node['node_3']['shape']
self.assertIsNone(res_shape)