Implementation of ArgMin ONNX + TF extractors (#5126)
* implement argmin extractors * reconsidering argmax to topk * arg ops refactoring * rename ArgMaxToTopK * added unittests * update docs * move unittest file to new folder * conversations resolving * revert changes with argmax.py, move argmin op to a new file * rename ArgMaxSqueeze * updated BOM file * little fix * code refactoring in ArgMaxOp, updated unittests Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
This commit is contained in:
@@ -128,6 +128,7 @@ Standard TensorFlow\* operations:
|
||||
| AddV2 | No |
|
||||
| AddN | No |
|
||||
| ArgMax | No |
|
||||
| ArgMin | No |
|
||||
| Asinh | No |
|
||||
| Atanh | No |
|
||||
| AvgPool | No |
|
||||
@@ -398,6 +399,7 @@ Standard ONNX\* operators:
|
||||
| Add | No |
|
||||
| Affine | No |
|
||||
| ArgMax | No |
|
||||
| ArgMin | No |
|
||||
| Asin | No |
|
||||
| Asinh | No |
|
||||
| Atan | No |
|
||||
|
||||
@@ -63,7 +63,7 @@ extensions/back/TopKNormalizer.py
|
||||
extensions/back/TransposeReduceFusing.py
|
||||
extensions/back/UselessConcatRemoval.py
|
||||
extensions/front/__init__.py
|
||||
extensions/front/ArgMaxSqueeze.py
|
||||
extensions/front/ArgOpsSqueeze.py
|
||||
extensions/front/ATenToEmbeddingBag.py
|
||||
extensions/front/AttributedClampNormalizer.py
|
||||
extensions/front/AttributedGatherNormalizer.py
|
||||
@@ -255,6 +255,7 @@ extensions/front/onnx/__init__.py
|
||||
extensions/front/onnx/activation_ext.py
|
||||
extensions/front/onnx/affine_ext.py
|
||||
extensions/front/onnx/argmax_ext.py
|
||||
extensions/front/onnx/argmin_ext.py
|
||||
extensions/front/onnx/aten_ext.py
|
||||
extensions/front/onnx/AttributedSliceToSlice.py
|
||||
extensions/front/onnx/cast_ext.py
|
||||
@@ -368,6 +369,7 @@ extensions/front/Swish_fusion.py
|
||||
extensions/front/tf/__init__.py
|
||||
extensions/front/tf/activation_ext.py
|
||||
extensions/front/tf/argmax_ext.py
|
||||
extensions/front/tf/argmin_ext.py
|
||||
extensions/front/tf/assign_elimination.py
|
||||
extensions/front/tf/automl_efficientdet.json
|
||||
extensions/front/tf/AutomlEfficientDet.py
|
||||
@@ -545,7 +547,7 @@ extensions/middle/AddIsCyclicAttribute.py
|
||||
extensions/middle/AddMeanScaleValues.py
|
||||
extensions/middle/ApplyNHWCtoNCHWpermutation.py
|
||||
extensions/middle/ApplyPermutations.py
|
||||
extensions/middle/ArgMaxToTopK.py
|
||||
extensions/middle/ArgOpsToTopK.py
|
||||
extensions/middle/AttributedTileNormalizer.py
|
||||
extensions/middle/BiasAddBroadcasting.py
|
||||
extensions/middle/BinarizeWeightsM1P1.py
|
||||
@@ -641,6 +643,7 @@ extensions/ops/accum.py
|
||||
extensions/ops/activation_ops.py
|
||||
extensions/ops/adaptive_avg_pooling.py
|
||||
extensions/ops/argmax.py
|
||||
extensions/ops/argmin.py
|
||||
extensions/ops/assert_op.py
|
||||
extensions/ops/aten.py
|
||||
extensions/ops/axpy.py
|
||||
|
||||
@@ -7,22 +7,21 @@ from mo.ops.const import Const
|
||||
from mo.ops.squeeze import Squeeze
|
||||
|
||||
|
||||
class ArgMaxSqueeze(FrontReplacementSubgraph):
|
||||
class ArgOpsSqueeze(FrontReplacementSubgraph):
|
||||
"""
|
||||
In some frameworks ArgMax operation has keepdims attribute that indicates whether to stay a dimension along
|
||||
which maximum is computed or not. In case of keepdims=0 this dimension should be removed but ArgMax operation in
|
||||
IR format is not designed to cover this case. So we should additionally add Squeeze operation right after ArgMax
|
||||
for this case.
|
||||
In some frameworks ArgMax/ArgMin operation has keepdims attribute that indicates whether to stay a dimension
|
||||
along which maximum is computed or not. In case of keepdims=0 this dimension should be removed but ArgMax/ArgMin
|
||||
operation in IR format is not designed to cover this case. So we should additionally add Squeeze operation right
|
||||
after ArgMax/ArgMin for this case.
|
||||
"""
|
||||
op = "ArgMax"
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(nodes=[('argmax', dict(op='ArgMax', keepdims=0))],
|
||||
return dict(nodes=[('node', dict(op=lambda x: x in ['ArgMax', 'ArgMin'], keepdims=0))],
|
||||
edges=[])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
node = match['argmax']
|
||||
node = match['node']
|
||||
|
||||
connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
squeeze_node = Squeeze(graph, dict()).create_node([], dict(name=node.name + '/Squeeze'))
|
||||
26
model-optimizer/extensions/front/onnx/argmin_ext.py
Normal file
26
model-optimizer/extensions/front/onnx/argmin_ext.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.ops.argmin import ArgMinOp
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
|
||||
|
||||
class ArgMinFrontExtractor(FrontExtractorOp):
|
||||
op = 'ArgMin'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
keepdims = onnx_attr(node, 'keepdims', 'i', default=1)
|
||||
axis = onnx_attr(node, 'axis', 'i', default=0)
|
||||
|
||||
attrs = {
|
||||
'axis': axis,
|
||||
'top_k': 1,
|
||||
'keepdims': keepdims,
|
||||
'remove_values_output': True
|
||||
}
|
||||
|
||||
ArgMinOp.update_node_stat(node, attrs)
|
||||
return cls.enabled
|
||||
@@ -16,6 +16,6 @@ class ArgMaxFrontExtractor(FrontExtractorOp):
|
||||
def extract(cls, node):
|
||||
ArgMaxOp.update_node_stat(node, {'out_max_val': 0, 'top_k': 1, 'axis': None,
|
||||
'dim_attrs': ['axis'], 'keepdims': 0, 'remove_values_output': True,
|
||||
'output_type': tf_dtype_extractor(node.pb.attr['out_type'].type, np.int64),
|
||||
'output_type': tf_dtype_extractor(node.pb.attr['output_type'].type, np.int64),
|
||||
})
|
||||
return cls.enabled
|
||||
|
||||
25
model-optimizer/extensions/front/tf/argmin_ext.py
Normal file
25
model-optimizer/extensions/front/tf/argmin_ext.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.argmin import ArgMinOp
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.tf.extractors.utils import tf_dtype_extractor
|
||||
|
||||
|
||||
class ArgMinFrontExtractor(FrontExtractorOp):
|
||||
op = 'ArgMin'
|
||||
enabled = True
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
attrs = {
|
||||
'top_k': 1,
|
||||
'axis': None,
|
||||
'keepdims': 0,
|
||||
'remove_values_output': True,
|
||||
'output_type': tf_dtype_extractor(node.pb.attr['output_type'].type, np.int64)
|
||||
}
|
||||
ArgMinOp.update_node_stat(node, attrs)
|
||||
return cls.enabled
|
||||
@@ -8,24 +8,24 @@ from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class ArgMaxToTopK(MiddleReplacementPattern):
|
||||
class ArgOpsToTopK(MiddleReplacementPattern):
|
||||
"""
|
||||
The transformation replaces ArgMax with the TopK layer.
|
||||
The transformation replaces ArgMax/ArgMin with the TopK layer.
|
||||
"""
|
||||
op = "ArgMax"
|
||||
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('argmax', dict(op='ArgMax')),
|
||||
('node', dict(op=lambda x: x in ['ArgMax', 'ArgMin'])),
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
node = match['argmax']
|
||||
node = match['node']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
@@ -36,7 +36,9 @@ class ArgMaxToTopK(MiddleReplacementPattern):
|
||||
|
||||
assert axis is not None, 'The "axis" should be defined for node "{}"'.format(node_name)
|
||||
assert node.has_and_set('output_type'), 'The data type is not set for node "{}"'.format(node_name)
|
||||
topk_node = TopK(graph, {'axis': axis, 'mode': 'max', 'sort': 'index',
|
||||
|
||||
topk_mode = 'max' if node.op == 'ArgMax' else 'min'
|
||||
topk_node = TopK(graph, {'axis': axis, 'mode': topk_mode, 'sort': 'index',
|
||||
'remove_values_output': node.has_and_set('remove_values_output'),
|
||||
'index_element_type': node.output_type}).create_node()
|
||||
node.in_port(0).get_connection().set_destination(topk_node.in_port(0))
|
||||
@@ -47,7 +49,7 @@ class ArgMaxToTopK(MiddleReplacementPattern):
|
||||
topk_node.out_port(0).connect(concat_node.in_port(1)) # indices
|
||||
topk_node.out_port(1).connect(concat_node.in_port(0)) # values
|
||||
if not node.out_port(0).disconnected():
|
||||
node.out_port(0).get_connection().set_source(concat_node.out_port(1))
|
||||
node.out_port(0).get_connection().set_source(concat_node.out_port(0))
|
||||
else:
|
||||
if not node.out_port(0).disconnected():
|
||||
node.out_port(0).get_connection().set_source(topk_node.out_port(1))
|
||||
@@ -6,18 +6,57 @@ import logging as log
|
||||
import numpy as np
|
||||
|
||||
from mo.front.caffe.extractors.utils import get_canonical_axis_index
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op, PermuteAttrs
|
||||
|
||||
|
||||
def arg_ops_infer(node: Node):
|
||||
shape = node.in_port(0).data.get_shape()
|
||||
node_name = node.soft_get('name', node.id)
|
||||
assert shape is not None, "Input shape for the node {} is None".format(node_name)
|
||||
|
||||
# there are two inputs in TensorFlow. The second input is the axis for ArgMax
|
||||
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
if len(connected_in_ports) == 2:
|
||||
axis = node.in_port(1).data.get_value()
|
||||
if axis is None:
|
||||
log.debug('The second argument to {} is None'.format(node.soft_get('name', node.id)))
|
||||
return
|
||||
node.axis = axis
|
||||
# remove the unnecessary input
|
||||
node.in_port(1).disconnect()
|
||||
|
||||
num_top_axes = shape.size
|
||||
if num_top_axes < 3:
|
||||
num_top_axes = 3
|
||||
|
||||
out_shape = np.ones(num_top_axes, dtype=np.int64)
|
||||
|
||||
if node.has_valid('axis'):
|
||||
axis = get_canonical_axis_index(shape, node.axis)
|
||||
node.axis = axis
|
||||
out_shape = int64_array(shape)
|
||||
out_shape[axis] = node.top_k
|
||||
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
|
||||
else:
|
||||
out_shape[0] = shape[0]
|
||||
out_shape[2] = node.top_k
|
||||
if node.has_and_set('out_max_val'):
|
||||
out_shape[1] = 2
|
||||
|
||||
node.out_port(0).data.set_shape(out_shape)
|
||||
|
||||
|
||||
class ArgMaxOp(Op):
|
||||
op = 'ArgMax'
|
||||
enabled = False
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': __class__.op,
|
||||
'op': __class__.op,
|
||||
'infer': ArgMaxOp.argmax_infer,
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'infer': arg_ops_infer,
|
||||
'output_type': np.int64,
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
@@ -30,38 +69,3 @@ class ArgMaxOp(Op):
|
||||
'top_k',
|
||||
'axis',
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def argmax_infer(node: Node):
|
||||
shape = node.in_node(0).shape
|
||||
if shape is None:
|
||||
return
|
||||
|
||||
# there are two inputs in TensorFlow. The second input is the axis for ArgMax
|
||||
if len(node.in_nodes()) == 2:
|
||||
if node.in_node(1).value is None:
|
||||
log.debug('The second argument to ArgMax is None')
|
||||
return
|
||||
node.axis = node.in_node(1).value.item()
|
||||
# remove the unnecessary input
|
||||
node.graph.remove_edge(node.in_node(1).id, node.id)
|
||||
|
||||
num_top_axes = shape.size
|
||||
if num_top_axes < 3:
|
||||
num_top_axes = 3
|
||||
|
||||
out_shape = np.ones(num_top_axes, dtype=int)
|
||||
|
||||
if node.has_valid('axis'):
|
||||
axis = get_canonical_axis_index(shape, node.axis)
|
||||
node.axis = axis
|
||||
out_shape = np.array(shape)
|
||||
out_shape[axis] = node.top_k
|
||||
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])
|
||||
else:
|
||||
out_shape[0] = shape[0]
|
||||
out_shape[2] = node.top_k
|
||||
if node.out_max_val:
|
||||
out_shape[1] = 2
|
||||
|
||||
node.out_node().shape = out_shape
|
||||
|
||||
30
model-optimizer/extensions/ops/argmin.py
Normal file
30
model-optimizer/extensions/ops/argmin.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.argmax import arg_ops_infer
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class ArgMinOp(Op):
|
||||
op = 'ArgMin'
|
||||
enabled = False
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': None,
|
||||
'op': self.op,
|
||||
'infer': arg_ops_infer,
|
||||
'output_type': np.int64,
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
||||
|
||||
def supported_attrs(self):
|
||||
return [
|
||||
'top_k',
|
||||
'axis',
|
||||
]
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from extensions.front.caffe.argmax_ext import ArgMaxFrontExtractor
|
||||
from extensions.ops.argmax import ArgMaxOp
|
||||
from extensions.ops.argmax import ArgMaxOp, arg_ops_infer
|
||||
from mo.ops.op import Op
|
||||
from unit_tests.utils.extractors import FakeMultiParam
|
||||
from unit_tests.utils.graph import FakeNode
|
||||
@@ -44,7 +44,7 @@ class TestArgMaxExt(unittest.TestCase):
|
||||
'out_max_val': True,
|
||||
'top_k': 100,
|
||||
'axis': 2,
|
||||
'infer': ArgMaxOp.argmax_infer,
|
||||
'infer': arg_ops_infer,
|
||||
'remove_values_output': True,
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.ArgOpsToTopK import ArgOpsToTopK
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import regular_op_with_empty_data, result, build_graph, connect, \
|
||||
valued_const_with_data, regular_op, empty_data, connect_front
|
||||
|
||||
nodes_attributes = {
|
||||
**regular_op_with_empty_data('input', {'op': 'Parameter', 'type': 'Parameter'}),
|
||||
**regular_op_with_empty_data('argmax', {'op': 'ArgMax', 'type': None, 'out_max_val': 0, 'top_k': 1, 'axis': 0,
|
||||
'output_type': np.int32, 'remove_values_output': True}),
|
||||
**regular_op_with_empty_data('argmin', {'op': 'ArgMin', 'type': None, 'top_k': 1, 'axis': 0,
|
||||
'output_type': np.int32, 'remove_values_output': True}),
|
||||
**result('result'),
|
||||
**valued_const_with_data('axis_const', int64_array([1])),
|
||||
|
||||
**regular_op('topk', {'op': 'TopK', 'type': 'TopK', 'sort': 'index', 'index_element_type': np.int32}),
|
||||
**empty_data('topk_out_0_data'),
|
||||
**empty_data('topk_out_1_data'),
|
||||
**regular_op_with_empty_data('topk_scalar', {'op': 'Const', 'type': 'Const', 'value': int64_array([1]),
|
||||
'shape': []}),
|
||||
|
||||
|
||||
**regular_op_with_empty_data('concat', {'op': 'Concat', 'type': 'Concat', 'axis': 1})
|
||||
}
|
||||
|
||||
|
||||
class ArgOpsToTopKTest(unittest.TestCase):
|
||||
|
||||
def test_tf_argmax_to_topk(self):
|
||||
graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:argmax'),
|
||||
*connect('axis_const', '1:argmax'),
|
||||
*connect('argmax', 'result')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
ArgOpsToTopK().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:topk'),
|
||||
*connect('topk_scalar', '1:topk'),
|
||||
*connect_front('topk:1', 'topk_out_1_data'),
|
||||
*connect_front('topk_out_1_data', 'result'),
|
||||
],
|
||||
update_attributes={
|
||||
'topk': {'axis': int64_array([1]), 'mode': 'max', 'remove_values_output': True},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_tf_argmin_to_topk(self):
|
||||
graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:argmin'),
|
||||
*connect('axis_const', '1:argmin'),
|
||||
*connect('argmin', 'result')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
ArgOpsToTopK().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:topk'),
|
||||
*connect('topk_scalar', '1:topk'),
|
||||
*connect_front('topk:1', 'topk_out_1_data'),
|
||||
*connect_front('topk_out_1_data', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'topk': {'axis': int64_array([1]), 'mode': 'min', 'remove_values_output': True},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_onnx_argmax_to_topk(self):
|
||||
graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', 'argmax'),
|
||||
*connect('argmax', 'result')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
ArgOpsToTopK().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:topk'),
|
||||
*connect('topk_scalar', '1:topk'),
|
||||
*connect_front('topk:1', 'topk_out_1_data'),
|
||||
*connect_front('topk_out_1_data', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'topk': {'axis': 0, 'mode': 'max', 'remove_values_output': True},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_onnx_argmin_to_topk(self):
|
||||
graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', 'argmin'),
|
||||
*connect('argmin', 'result')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
ArgOpsToTopK().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:topk'),
|
||||
*connect('topk_scalar', '1:topk'),
|
||||
*connect_front('topk:1', 'topk_out_1_data'),
|
||||
*connect_front('topk_out_1_data', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'topk': {'axis': 0, 'mode': 'min', 'remove_values_output': True},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_caffe_argmax_to_topk(self):
|
||||
graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', 'argmax'),
|
||||
*connect('argmax', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'argmax': {'out_max_val': 1}
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
ArgOpsToTopK().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes_attrs=nodes_attributes,
|
||||
edges=[
|
||||
*connect('input', '0:topk'),
|
||||
*connect('topk_scalar', '1:topk'),
|
||||
*connect_front('topk:0','topk_out_0_data'),
|
||||
*connect_front('topk:1', 'topk_out_1_data'),
|
||||
*connect_front('topk_out_0_data', '1:concat'),
|
||||
*connect_front('topk_out_1_data', '0:concat'),
|
||||
*connect('concat', 'result')
|
||||
],
|
||||
update_attributes={
|
||||
'topk': {'axis': 0, 'mode': 'max', 'remove_values_output': True},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
@@ -5,21 +5,25 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.argmax import ArgMaxOp
|
||||
from extensions.ops.argmax import arg_ops_infer
|
||||
from mo.graph.graph import Node
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
nodes_attributes = {'node_1': {'type': 'Identity', 'kind': 'op'},
|
||||
nodes_attributes = {
|
||||
'op_input': {'kind': 'op', 'op': 'Parameter'},
|
||||
'node_1': {'kind': 'data'},
|
||||
'argmax': {'op': 'ArgMax', 'kind': 'op'},
|
||||
'node_3': {'type': 'Identity', 'kind': 'op'},
|
||||
'op_output': { 'kind': 'op', 'op': 'Result'}
|
||||
'node_3': {'kind': 'data', 'value': None},
|
||||
'op_output': {'kind': 'op', 'op': 'Result'}
|
||||
}
|
||||
|
||||
|
||||
class TestArgMaxOp(unittest.TestCase):
|
||||
def test_caffe_argmax_axis(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('node_1', 'argmax'),
|
||||
[
|
||||
('op_input', 'node_1'),
|
||||
('node_1', 'argmax'),
|
||||
('argmax', 'node_3'),
|
||||
('node_3', 'op_output')
|
||||
],
|
||||
@@ -33,7 +37,7 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
})
|
||||
|
||||
argmax_node = Node(graph, 'argmax')
|
||||
ArgMaxOp.argmax_infer(argmax_node)
|
||||
arg_ops_infer(argmax_node)
|
||||
exp_shape = np.array([1, 3, 100, 2049])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
for i in range(0, len(exp_shape)):
|
||||
@@ -41,7 +45,9 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
|
||||
def test_caffe_argmax_axis_negative(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('node_1', 'argmax'),
|
||||
[
|
||||
('op_input', 'node_1'),
|
||||
('node_1', 'argmax'),
|
||||
('argmax', 'node_3'),
|
||||
('node_3', 'op_output')
|
||||
],
|
||||
@@ -55,7 +61,7 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
})
|
||||
|
||||
argmax_node = Node(graph, 'argmax')
|
||||
ArgMaxOp.argmax_infer(argmax_node)
|
||||
arg_ops_infer(argmax_node)
|
||||
exp_shape = np.array([1, 3, 1025, 100])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
self.assertEqual(argmax_node.axis, 3)
|
||||
@@ -64,7 +70,9 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
|
||||
def test_caffe_argmax_no_axis(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('node_1', 'argmax'),
|
||||
[
|
||||
('op_input', 'node_1'),
|
||||
('node_1', 'argmax'),
|
||||
('argmax', 'node_3'),
|
||||
('node_3', 'op_output')
|
||||
],
|
||||
@@ -77,7 +85,7 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
})
|
||||
|
||||
argmax_node = Node(graph, 'argmax')
|
||||
ArgMaxOp.argmax_infer(argmax_node)
|
||||
arg_ops_infer(argmax_node)
|
||||
exp_shape = np.array([1, 2, 100, 1])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
for i in range(0, len(exp_shape)):
|
||||
@@ -85,7 +93,9 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
|
||||
def test_caffe_argmax_extend_shape(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('node_1', 'argmax'),
|
||||
[
|
||||
('op_input', 'node_1'),
|
||||
('node_1', 'argmax'),
|
||||
('argmax', 'node_3'),
|
||||
('node_3', 'op_output')
|
||||
],
|
||||
@@ -98,7 +108,7 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
})
|
||||
|
||||
argmax_node = Node(graph, 'argmax')
|
||||
ArgMaxOp.argmax_infer(argmax_node)
|
||||
arg_ops_infer(argmax_node)
|
||||
exp_shape = np.array([1, 2, 100])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
for i in range(0, len(exp_shape)):
|
||||
@@ -106,7 +116,9 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
|
||||
def test_caffe_argmax_out_max_val_false(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('node_1', 'argmax'),
|
||||
[
|
||||
('op_input', 'node_1'),
|
||||
('node_1', 'argmax'),
|
||||
('argmax', 'node_3'),
|
||||
('node_3', 'op_output')
|
||||
],
|
||||
@@ -119,27 +131,8 @@ class TestArgMaxOp(unittest.TestCase):
|
||||
})
|
||||
|
||||
argmax_node = Node(graph, 'argmax')
|
||||
ArgMaxOp.argmax_infer(argmax_node)
|
||||
arg_ops_infer(argmax_node)
|
||||
exp_shape = np.array([1, 1, 100])
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
for i in range(0, len(exp_shape)):
|
||||
self.assertEqual(exp_shape[i], res_shape[i])
|
||||
|
||||
def test_caffe_argmax_no_shape(self):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('node_1', 'argmax'),
|
||||
('argmax', 'node_3'),
|
||||
('node_3', 'op_output')
|
||||
],
|
||||
{'node_3': {'shape': None},
|
||||
'node_1': {'shape': None},
|
||||
'argmax': {
|
||||
'out_max_val': False,
|
||||
'top_k': 100
|
||||
}
|
||||
})
|
||||
|
||||
argmax_node = Node(graph, 'argmax')
|
||||
ArgMaxOp.argmax_infer(argmax_node)
|
||||
res_shape = graph.node['node_3']['shape']
|
||||
self.assertIsNone(res_shape)
|
||||
|
||||
Reference in New Issue
Block a user