[MO] Fix NMS 3rd output shape (#7992)

* Fix NMS 3rd output shape

* Add tests for NMS_5 shape infer

* Add comments, fix codestyle
This commit is contained in:
Anton Chetverikov 2021-10-15 14:44:08 +03:00 committed by GitHub
parent e8f2249d8e
commit e034a072ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 6 deletions

View File

@ -101,7 +101,7 @@ class NonMaxSuppression(Op):
if num_of_outputs >= 2 and node.has_port('out', 1):
node.out_port(1).data.set_shape(shape_array([dynamic_dimension_value, 3]))
if num_of_outputs >= 3 and node.has_port('out', 2):
node.out_port(2).data.set_shape(shape_array(1))
node.out_port(2).data.set_shape(shape_array([1]))
@staticmethod
def type_infer(node):

View File

@ -8,7 +8,8 @@ import numpy as np
from extensions.ops.non_max_suppression import NonMaxSuppression
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, connect
from unit_tests.utils.graph import build_graph, regular_op, regular_op_with_shaped_data, valued_const_with_data, result, connect, empty_data
from mo.front.common.partial_infer.utils import shape_array, dynamic_dimension_value
class TestNonMaxSuppressionInfer(unittest.TestCase):
@ -17,16 +18,44 @@ class TestNonMaxSuppressionInfer(unittest.TestCase):
**regular_op_with_shaped_data('boxes', [10, 100, 4], {'type': 'Parameter'}),
**regular_op_with_shaped_data('scores', [10, 5, 100], {'type': 'Parameter'}),
**valued_const_with_data('max_output_per_class', int64_array(7)),
**regular_op_with_shaped_data('nms', None, {'op': 'NonMaxSuppression', 'type': 'NonMaxSuppression',
'name': 'nms'}),
**result('output'),
**regular_op('nms', {'op': 'NonMaxSuppression', 'type': 'NonMaxSuppression', 'name': 'nms'}),
**empty_data('nms_data_0'),
**empty_data('nms_data_1'),
**empty_data('nms_data_2'),
**result('output_0'),
**result('output_1'),
**result('output_2'),
}
self.graph = build_graph(nodes, [
*connect('boxes', '0:nms'),
*connect('scores', '1:nms'),
*connect('max_output_per_class', '2:nms'),
*connect('nms', 'output'),
*connect('nms:0', 'nms_data_0', front_phase=True), # Use this WA for correct creating operation
*connect('nms_data_0', 'output_0', front_phase=True), # with multiple outputs
], nodes_with_edges_only=True)
self.graph_nms_5_2_outs = build_graph(nodes, [
*connect('boxes', '0:nms'),
*connect('scores', '1:nms'),
*connect('max_output_per_class', '2:nms'),
*connect('nms:0', 'nms_data_0', front_phase=True), # Use this WA for correct creating operation
*connect('nms_data_0', 'output_0', front_phase=True), # with multiple outputs
*connect('nms:1', 'nms_data_1', front_phase=True),
*connect('nms_data_1', 'output_1', front_phase=True),
], nodes_with_edges_only=True)
self.graph_nms_5_3_outs = build_graph(nodes, [
*connect('boxes', '0:nms'),
*connect('scores', '1:nms'),
*connect('max_output_per_class', '2:nms'),
*connect('nms:0', 'nms_data_0', front_phase=True), # Use this WA for correct creating operation
*connect('nms_data_0', 'output_0', front_phase=True), # with multiple outputs
*connect('nms:1', 'nms_data_1', front_phase=True),
*connect('nms_data_1', 'output_1', front_phase=True),
*connect('nms:2', 'nms_data_2', front_phase=True),
*connect('nms_data_2', 'output_2', front_phase=True),
], nodes_with_edges_only=True)
def test_nms_infer_opset1(self):
@ -77,3 +106,85 @@ class TestNonMaxSuppressionInfer(unittest.TestCase):
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(), [10 * 5 * 7, 3]))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int64)
def test_nms_infer_i32_opset5_1_out(self):
nms_node = Node(self.graph, 'nms')
nms_node['version'] = 'opset5'
nms_node['output_type'] = np.int32
NonMaxSuppression.infer(nms_node)
NonMaxSuppression.type_infer(nms_node)
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int32)
def test_nms_infer_i64_opset5_1_out(self):
nms_node = Node(self.graph, 'nms')
nms_node['version'] = 'opset5'
nms_node['output_type'] = np.int64
NonMaxSuppression.infer(nms_node)
NonMaxSuppression.type_infer(nms_node)
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int64)
def test_nms_infer_i32_opset5_2_outs(self):
nms_node = Node(self.graph_nms_5_2_outs, 'nms')
nms_node['version'] = 'opset5'
nms_node['output_type'] = np.int32
NonMaxSuppression.infer(nms_node)
NonMaxSuppression.type_infer(nms_node)
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(np.array_equal(nms_node.out_port(1).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int32)
self.assertTrue(nms_node.out_port(1).get_data_type() == np.float32)
def test_nms_infer_i64_opset5_2_outs(self):
nms_node = Node(self.graph_nms_5_2_outs, 'nms')
nms_node['version'] = 'opset5'
nms_node['output_type'] = np.int64
NonMaxSuppression.infer(nms_node)
NonMaxSuppression.type_infer(nms_node)
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(np.array_equal(nms_node.out_port(1).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int64)
self.assertTrue(nms_node.out_port(1).get_data_type() == np.float32)
def test_nms_infer_i32_opset5_3_outs(self):
nms_node = Node(self.graph_nms_5_3_outs, 'nms')
nms_node['version'] = 'opset5'
nms_node['output_type'] = np.int32
NonMaxSuppression.infer(nms_node)
NonMaxSuppression.type_infer(nms_node)
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(np.array_equal(nms_node.out_port(1).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(np.array_equal(nms_node.out_port(2).data.get_shape(), [1]))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int32)
self.assertTrue(nms_node.out_port(1).get_data_type() == np.float32)
self.assertTrue(nms_node.out_port(2).get_data_type() == np.int64)
def test_nms_infer_i64_opset5_3_outs(self):
nms_node = Node(self.graph_nms_5_3_outs, 'nms')
nms_node['version'] = 'opset5'
nms_node['output_type'] = np.int64
NonMaxSuppression.infer(nms_node)
NonMaxSuppression.type_infer(nms_node)
self.assertTrue(np.array_equal(nms_node.out_port(0).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(np.array_equal(nms_node.out_port(1).data.get_shape(),
shape_array([dynamic_dimension_value, 3])))
self.assertTrue(np.array_equal(nms_node.out_port(2).data.get_shape(), [1]))
self.assertTrue(nms_node.out_port(0).get_data_type() == np.int64)
self.assertTrue(nms_node.out_port(1).get_data_type() == np.float32)
self.assertTrue(nms_node.out_port(2).get_data_type() == np.int64)