Files
openvino/model-optimizer/unit_tests/extensions/middle/ArgOpsToTopK_test.py
Yegor Kruglov abb1ca657e Implementation of ArgMin ONNX + TF extractors (#5126)
* implement argmin extractors

* reconsidering argmax to topk

* arg ops refactoring

* rename ArgMaxToTopK

* added unittests

* update docs

* move unittest file to new folder

* resolve review conversations

* revert changes to argmax.py, move argmin op to a new file

* rename ArgMaxSqueeze

* updated BOM file

* little fix

* code refactoring in ArgMaxOp, updated unittests

Co-authored-by: yegor.kruglov <ykruglov@nnlvdp-mkaglins.inn.intel.com>
2021-05-06 13:41:49 +03:00

152 lines
7.9 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.middle.ArgOpsToTopK import ArgOpsToTopK
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import regular_op_with_empty_data, result, build_graph, connect, \
valued_const_with_data, regular_op, empty_data, connect_front
# Attributes shared by the ArgMax and ArgMin nodes of the source graphs below.
_ARG_OP_COMMON_ATTRS = {
    'type': None,
    'top_k': 1,
    'axis': 0,
    'output_type': np.int32,
    'remove_values_output': True,
}

# Node table shared by every test: the source-graph nodes (Parameter, arg ops, axis constant, Result)
# and the reference-graph nodes (TopK, its K constant, its two output data nodes, and a Concat).
nodes_attributes = {
    # --- source graph ---
    **regular_op_with_empty_data('input', {'op': 'Parameter', 'type': 'Parameter'}),
    **regular_op_with_empty_data('argmax', {'op': 'ArgMax', 'out_max_val': 0, **_ARG_OP_COMMON_ATTRS}),
    **regular_op_with_empty_data('argmin', {'op': 'ArgMin', **_ARG_OP_COMMON_ATTRS}),
    **result('result'),
    # Optional second input of the arg ops; its value is what the tests expect to end up as TopK's axis.
    **valued_const_with_data('axis_const', int64_array([1])),

    # --- reference graph ---
    **regular_op('topk', {'op': 'TopK', 'type': 'TopK', 'sort': 'index', 'index_element_type': np.int32}),
    **empty_data('topk_out_0_data'),
    **empty_data('topk_out_1_data'),
    # K input of TopK.
    # NOTE(review): 'shape' is declared as [] while 'value' is a 1-element array; left unchanged — confirm
    # which one is intended.
    **regular_op_with_empty_data('topk_scalar', {'op': 'Const', 'type': 'Const', 'value': int64_array([1]),
                                                 'shape': []}),
    **regular_op_with_empty_data('concat', {'op': 'Concat', 'type': 'Concat', 'axis': 1}),
}
class ArgOpsToTopKTest(unittest.TestCase):
    """Tests for the ArgOpsToTopK middle transformation.

    Each test builds a graph from ``nodes_attributes`` that contains an ArgMax or
    ArgMin node and runs ``ArgOpsToTopK().find_and_replace_pattern`` on it. It then
    compares the result, starting from the 'input' node and including op
    attributes, against a hand-built reference graph that uses a TopK node.
    """

    @staticmethod
    def _build_arg_op_graph(arg_op, with_axis_input=False, update_attributes=None):
        """Build the source graph ``input -> <arg_op> -> result``.

        :param arg_op: id of the arg op node in ``nodes_attributes`` ('argmax' or 'argmin').
        :param with_axis_input: if True, 'axis_const' is also connected to port 1 of the
            arg op (the form used by the TF tests). Otherwise the arg op has only one
            input (the form used by the ONNX and Caffe tests).
        :param update_attributes: optional per-node attribute overrides forwarded to build_graph.
        :return: the built graph.
        """
        if with_axis_input:
            edges = [
                *connect('input', '0:' + arg_op),
                *connect('axis_const', '1:' + arg_op),
                *connect(arg_op, 'result'),
            ]
        else:
            edges = [
                *connect('input', arg_op),
                *connect(arg_op, 'result'),
            ]
        # Only forward update_attributes when given, so the call matches the original tests exactly.
        extra_kwargs = {}
        if update_attributes is not None:
            extra_kwargs['update_attributes'] = update_attributes
        return build_graph(nodes_attrs=nodes_attributes, edges=edges, nodes_with_edges_only=True,
                           **extra_kwargs)

    @staticmethod
    def _build_topk_ref_graph(axis, mode):
        """Build the reference graph ``input -> TopK -> result``.

        In this graph only TopK's output port 1 feeds the result.

        :param axis: expected value of the TopK 'axis' attribute.
        :param mode: expected value of the TopK 'mode' attribute ('max' or 'min').
        :return: the built reference graph.
        """
        return build_graph(nodes_attrs=nodes_attributes,
                           edges=[
                               *connect('input', '0:topk'),
                               *connect('topk_scalar', '1:topk'),
                               *connect_front('topk:1', 'topk_out_1_data'),
                               *connect_front('topk_out_1_data', 'result'),
                           ],
                           update_attributes={
                               'topk': {'axis': axis, 'mode': mode, 'remove_values_output': True},
                           },
                           nodes_with_edges_only=True)

    def _assert_graphs_equal(self, graph, ref_graph):
        """Fail the test unless ``graph`` matches ``ref_graph``.

        The comparison starts from 'input' and includes op attributes.
        """
        (flag, resp) = compare_graphs(graph, ref_graph, 'input', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_tf_argmax_to_topk(self):
        graph = self._build_arg_op_graph('argmax', with_axis_input=True)
        ArgOpsToTopK().find_and_replace_pattern(graph)
        # The reference expects the axis value fed through 'axis_const' (int64_array([1])),
        # not the node's own 'axis': 0 attribute.
        ref_graph = self._build_topk_ref_graph(int64_array([1]), 'max')
        self._assert_graphs_equal(graph, ref_graph)

    def test_tf_argmin_to_topk(self):
        graph = self._build_arg_op_graph('argmin', with_axis_input=True)
        ArgOpsToTopK().find_and_replace_pattern(graph)
        ref_graph = self._build_topk_ref_graph(int64_array([1]), 'min')
        self._assert_graphs_equal(graph, ref_graph)

    def test_onnx_argmax_to_topk(self):
        graph = self._build_arg_op_graph('argmax')
        ArgOpsToTopK().find_and_replace_pattern(graph)
        # With no axis input, the reference expects the node's 'axis': 0 attribute.
        ref_graph = self._build_topk_ref_graph(0, 'max')
        self._assert_graphs_equal(graph, ref_graph)

    def test_onnx_argmin_to_topk(self):
        graph = self._build_arg_op_graph('argmin')
        ArgOpsToTopK().find_and_replace_pattern(graph)
        ref_graph = self._build_topk_ref_graph(0, 'min')
        self._assert_graphs_equal(graph, ref_graph)

    def test_caffe_argmax_to_topk(self):
        graph = self._build_arg_op_graph('argmax', update_attributes={'argmax': {'out_max_val': 1}})
        ArgOpsToTopK().find_and_replace_pattern(graph)
        # With out_max_val set, the reference expects both TopK outputs to be concatenated:
        # TopK port 1 -> Concat input 0 and TopK port 0 -> Concat input 1.
        ref_graph = build_graph(nodes_attrs=nodes_attributes,
                                edges=[
                                    *connect('input', '0:topk'),
                                    *connect('topk_scalar', '1:topk'),
                                    *connect_front('topk:0', 'topk_out_0_data'),
                                    *connect_front('topk:1', 'topk_out_1_data'),
                                    *connect_front('topk_out_0_data', '1:concat'),
                                    *connect_front('topk_out_1_data', '0:concat'),
                                    *connect('concat', 'result'),
                                ],
                                update_attributes={
                                    'topk': {'axis': 0, 'mode': 'max', 'remove_values_output': True},
                                },
                                nodes_with_edges_only=True)
        self._assert_graphs_equal(graph, ref_graph)