Extend MO for the operation NonMaxSuppression-5 (#2356)
* Commit. * Written draft of NonMaxSuppression-5 class. * Written conversion of the value of the second output of MO NonMaxSuppression-5 into TF format. * Fixed type infer for the port 1 of NonMaxSuppression-5. * Added Reshape to [1] for 0D inputs of NMS-5. * Small fix. * Corrected assert for number of inputs. * Fixed docstrings for transformations TFNonMaxSuppressionNormalize and NonMaxSuppressionNormalize. * Now the transformation TFNonMaxSuppressionNormalize uses find_and_replace_pattern(). * Moved model-optimizer/extensions/front/onnx/non_max_suppression_normalize.py to model-optimizer/extensions/front/non_max_suppression_normalize.py, to delete duplicate code. * Deleted commented code. * Fixed BOM-file. * Deleted out_ports_count from NMS. * Fixes in type_infer of NMS-5. * Small changes. * Added some comment. * Small fix. * Some fixes.
This commit is contained in:
parent
e935d0bd22
commit
da47cb05be
@ -230,6 +230,7 @@ extensions/front/mxnet/where_ext.py
|
|||||||
extensions/front/mxnet/yolo_v3_mobilenet1_voc.json
|
extensions/front/mxnet/yolo_v3_mobilenet1_voc.json
|
||||||
extensions/front/mxnet/zeros_ext.py
|
extensions/front/mxnet/zeros_ext.py
|
||||||
extensions/front/no_op_eraser.py
|
extensions/front/no_op_eraser.py
|
||||||
|
extensions/front/non_max_suppression_normalize.py
|
||||||
extensions/front/OneHotDepthNormalizer.py
|
extensions/front/OneHotDepthNormalizer.py
|
||||||
extensions/front/onnx/__init__.py
|
extensions/front/onnx/__init__.py
|
||||||
extensions/front/onnx/activation_ext.py
|
extensions/front/onnx/activation_ext.py
|
||||||
@ -276,7 +277,6 @@ extensions/front/onnx/mask_rcnn_conversion.py
|
|||||||
extensions/front/onnx/matmul_ext.py
|
extensions/front/onnx/matmul_ext.py
|
||||||
extensions/front/onnx/mean_variance_normalization_ext.py
|
extensions/front/onnx/mean_variance_normalization_ext.py
|
||||||
extensions/front/onnx/non_max_suppression_ext.py
|
extensions/front/onnx/non_max_suppression_ext.py
|
||||||
extensions/front/onnx/non_max_suppression_normalize.py
|
|
||||||
extensions/front/onnx/non_zero_ext.py
|
extensions/front/onnx/non_zero_ext.py
|
||||||
extensions/front/onnx/normalize_ext.py
|
extensions/front/onnx/normalize_ext.py
|
||||||
extensions/front/onnx/normalize_l2_normalize.py
|
extensions/front/onnx/normalize_l2_normalize.py
|
||||||
|
@ -23,17 +23,17 @@ from mo.ops.reshape import Reshape
|
|||||||
|
|
||||||
class NonMaxSuppressionNormalize(FrontReplacementSubgraph):
|
class NonMaxSuppressionNormalize(FrontReplacementSubgraph):
|
||||||
"""
|
"""
|
||||||
The transformation converts several inputs of the NonMaxSuppression layer to be 0D instead of 1D with shape [1] to
|
The transformation converts several inputs of the NonMaxSuppression layer to be 1D instead of 0D with shape [1] to
|
||||||
comply with the layer specification.
|
comply with the layer specification.
|
||||||
"""
|
"""
|
||||||
enabled = True
|
enabled = True
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
for nms in graph.get_op_nodes(op='NonMaxSuppression'):
|
for nms in graph.get_op_nodes(op='NonMaxSuppression'):
|
||||||
# make inputs 2 to 4 to have shape [] instead of [1] (convert 1D to 0D)
|
# make inputs 2 to 5 to have shape [1] instead of [0] (convert 0D to 1D)
|
||||||
for port_id in range(2, 5):
|
nms_name = nms.soft_get('name', nms.id)
|
||||||
|
for port_id in range(2, 6):
|
||||||
if port_id in nms.in_ports() and not nms.in_port(port_id).disconnected():
|
if port_id in nms.in_ports() and not nms.in_port(port_id).disconnected():
|
||||||
reshape_1d = create_op_node_with_second_input(graph, Reshape, int64_array([]),
|
reshape_1d = create_op_node_with_second_input(graph, Reshape, int64_array([1]),
|
||||||
{'name': nms.soft_get('name') +
|
{'name': nms_name + '/Reshape_1D_{}'.format(port_id)})
|
||||||
'/Reshape_0D'.format(port_id)})
|
|
||||||
nms.in_port(port_id).get_connection().insert_node(reshape_1d)
|
nms.in_port(port_id).get_connection().insert_node(reshape_1d)
|
@ -34,27 +34,23 @@ class TFNonMaxSuppressionNormalize(FrontReplacementSubgraph):
|
|||||||
TF inputs: boxes = [num_boxes, 4]
|
TF inputs: boxes = [num_boxes, 4]
|
||||||
scores = [num_boxes]
|
scores = [num_boxes]
|
||||||
outputs: box_indices [selected_boxes_count]
|
outputs: box_indices [selected_boxes_count]
|
||||||
|
box_scores [selected_boxes_count]
|
||||||
|
valid_outputs selected_boxes_count
|
||||||
|
|
||||||
IE inputs: boxes = [num_batches, num_boxes, 4]
|
IE inputs: boxes = [num_batches, num_boxes, 4]
|
||||||
scores = [num_batches, num_classes, num_boxes]
|
scores = [num_batches, num_classes, num_boxes]
|
||||||
outputs: selected_indices [num_selected_indices, 3] where each element is [batch_index, class_index, box_index]
|
outputs: selected_indices [num_selected_indices, 3] where each element is [batch_index, class_index, box_index]
|
||||||
|
selected_scores [num_selected_indices, 3] where each element is [batch_index, class_index, box_score]
|
||||||
|
valid_outputs num_selected_indices
|
||||||
"""
|
"""
|
||||||
enabled = True
|
enabled = True
|
||||||
|
|
||||||
@staticmethod
|
def run_after(self):
|
||||||
def pattern(**kwargs):
|
from extensions.front.non_max_suppression_normalize import NonMaxSuppressionNormalize
|
||||||
return dict(
|
return [NonMaxSuppressionNormalize]
|
||||||
nodes=[
|
|
||||||
('nms', dict(op='NonMaxSuppression')),
|
|
||||||
],
|
|
||||||
edges=[
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def replace_sub_graph(graph: Graph, match: dict, **kwargs):
|
|
||||||
nms = match['nms']
|
|
||||||
|
|
||||||
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
for nms in graph.get_op_nodes(op='NonMaxSuppression'):
|
||||||
# prepare inputs to the NonMaximumSuppression Node
|
# prepare inputs to the NonMaximumSuppression Node
|
||||||
unsqueeze_boxes = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
|
unsqueeze_boxes = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
|
||||||
{'name': nms.soft_get('name') + '/Unsqueeze_0'})
|
{'name': nms.soft_get('name') + '/Unsqueeze_0'})
|
||||||
@ -64,18 +60,27 @@ class TFNonMaxSuppressionNormalize(FrontReplacementSubgraph):
|
|||||||
{'name': nms.soft_get('name') + '/Unsqueeze_1'})
|
{'name': nms.soft_get('name') + '/Unsqueeze_1'})
|
||||||
nms.in_port(1).get_connection().insert_node(unsqueeze_box_scores)
|
nms.in_port(1).get_connection().insert_node(unsqueeze_box_scores)
|
||||||
|
|
||||||
# prepare output
|
nms_name = nms.soft_get('name', nms.id)
|
||||||
crop_box_indices = Crop(graph, {'name': nms.soft_get('name') + '/Crop', 'axis': int64_array([1]),
|
|
||||||
|
# prepare output #0
|
||||||
|
crop_box_indices_name = nms_name + '/Crop_boxes_'
|
||||||
|
crop_box_indices = Crop(graph, {'name': crop_box_indices_name, 'axis': int64_array([1]),
|
||||||
'offset': int64_array([2]), 'dim': int64_array([1])}).create_node()
|
'offset': int64_array([2]), 'dim': int64_array([1])}).create_node()
|
||||||
nms.out_port(0).get_connection().insert_node(crop_box_indices)
|
nms.out_port(0).get_connection().insert_node(crop_box_indices)
|
||||||
squeeze_output_boxes = create_op_node_with_second_input(graph, Squeeze, int64_array([1]),
|
squeeze_output_boxes = create_op_node_with_second_input(graph, Squeeze, int64_array([1]),
|
||||||
{'name': crop_box_indices.soft_get('name') + '/Squeeze'}
|
{'name': crop_box_indices_name + '/Squeeze'})
|
||||||
)
|
|
||||||
crop_box_indices.out_port(0).get_connection().insert_node(squeeze_output_boxes)
|
crop_box_indices.out_port(0).get_connection().insert_node(squeeze_output_boxes)
|
||||||
|
|
||||||
if 5 in nms.in_ports() and not nms.in_port(5).disconnected():
|
num_of_outputs = len([port for port in nms.out_ports().values() if not port.disconnected()])
|
||||||
soft_nms_sigma = nms.in_port(5).get_source().data.get_value()
|
|
||||||
if soft_nms_sigma is not None and soft_nms_sigma != 0.0:
|
if num_of_outputs == 1:
|
||||||
log.error('The input to layer "{}" with value for the soft_nms_sigma is equal to "{}" but only value 0'
|
return
|
||||||
'is supported. The inference results will be incorrect.'.format(nms.soft_get('name'),
|
|
||||||
soft_nms_sigma))
|
# prepare output #1
|
||||||
|
crop_score_indices_name = nms_name + '/Crop_scores_'
|
||||||
|
crop_score_indices = Crop(graph, {'name': crop_score_indices_name, 'axis': int64_array([1]),
|
||||||
|
'offset': int64_array([2]), 'dim': int64_array([1])}).create_node()
|
||||||
|
nms.out_port(1).get_connection().insert_node(crop_score_indices)
|
||||||
|
squeeze_output_scores = create_op_node_with_second_input(graph, Squeeze, int64_array([1]),
|
||||||
|
{'name': crop_score_indices_name + '/Squeeze'})
|
||||||
|
crop_score_indices.out_port(0).get_connection().insert_node(squeeze_output_scores)
|
||||||
|
@ -31,23 +31,29 @@ class NonMaxSuppression(Op):
|
|||||||
mandatory_props = {
|
mandatory_props = {
|
||||||
'type': self.op,
|
'type': self.op,
|
||||||
'op': self.op,
|
'op': self.op,
|
||||||
'version': 'opset4',
|
'version': 'opset5',
|
||||||
'infer': self.infer,
|
'infer': self.infer,
|
||||||
'output_type': np.int64,
|
'output_type': np.int64,
|
||||||
'center_point_box': 0,
|
'center_point_box': 0,
|
||||||
'box_encoding': 'corner',
|
'box_encoding': 'corner',
|
||||||
'in_ports_count': 5,
|
'in_ports_count': 5,
|
||||||
'out_ports_count': 1,
|
|
||||||
'sort_result_descending': 1,
|
'sort_result_descending': 1,
|
||||||
'force_precision_in_ports': {
|
'force_precision_in_ports': {
|
||||||
2: 'int64'},
|
2: 'int64'},
|
||||||
'type_infer': self.type_infer,
|
'type_infer': self.type_infer,
|
||||||
}
|
}
|
||||||
super().__init__(graph, mandatory_props, attrs)
|
super().__init__(graph, mandatory_props, attrs)
|
||||||
|
version = self.get_opset()
|
||||||
|
if version in ['opset1', 'opset3', 'opset4']:
|
||||||
|
self.attrs['out_ports_count'] = 1
|
||||||
|
elif version == 'opset5':
|
||||||
|
self.attrs['out_ports_count'] = 3
|
||||||
|
else:
|
||||||
|
raise Error('Unsupported operation opset version "{}"'.format(version))
|
||||||
|
|
||||||
def backend_attrs(self):
|
def backend_attrs(self):
|
||||||
version = self.get_opset()
|
version = self.get_opset()
|
||||||
if version in ['opset3', 'opset4']:
|
if version in ['opset3', 'opset4', 'opset5']:
|
||||||
return ['sort_result_descending', 'box_encoding',
|
return ['sort_result_descending', 'box_encoding',
|
||||||
('output_type', lambda node: np_data_type_to_destination_type(node.output_type))]
|
('output_type', lambda node: np_data_type_to_destination_type(node.output_type))]
|
||||||
elif version == 'opset1':
|
elif version == 'opset1':
|
||||||
@ -57,6 +63,13 @@ class NonMaxSuppression(Op):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def infer(node: Node):
|
def infer(node: Node):
|
||||||
|
num_of_inputs = len(node.in_ports())
|
||||||
|
opset = node.get_opset()
|
||||||
|
max_num_of_inputs = 6 if opset == 'opset5' else 5
|
||||||
|
input_msg_fmt = 'NonMaxSuppression node {} from {} must have from 2 to {} inputs'
|
||||||
|
inputs_msg = input_msg_fmt.format(node.soft_get('name', node.id), opset, max_num_of_inputs)
|
||||||
|
assert 2 <= num_of_inputs <= max_num_of_inputs, inputs_msg
|
||||||
|
|
||||||
boxes_shape = node.in_port(0).data.get_shape()
|
boxes_shape = node.in_port(0).data.get_shape()
|
||||||
assert boxes_shape is not None, 'The shape of tensor with boxes is not defined'
|
assert boxes_shape is not None, 'The shape of tensor with boxes is not defined'
|
||||||
scores_shape = node.in_port(1).data.get_shape()
|
scores_shape = node.in_port(1).data.get_shape()
|
||||||
@ -64,8 +77,14 @@ class NonMaxSuppression(Op):
|
|||||||
assert len(boxes_shape) == 3, 'Length of tensors with boxes must be equal to 3'
|
assert len(boxes_shape) == 3, 'Length of tensors with boxes must be equal to 3'
|
||||||
assert len(scores_shape) == 3, 'Length of tensors with scores must be equal to 3'
|
assert len(scores_shape) == 3, 'Length of tensors with scores must be equal to 3'
|
||||||
|
|
||||||
|
# According to the specification of the operation NonMaxSuppression,
|
||||||
|
# the input 'max_output_boxes_per_class' (port 2) is optional, with default value 0.
|
||||||
|
if num_of_inputs >= 3:
|
||||||
max_output_boxes_per_class = node.in_port(2).data.get_value()
|
max_output_boxes_per_class = node.in_port(2).data.get_value()
|
||||||
if max_output_boxes_per_class is None:
|
else:
|
||||||
|
max_output_boxes_per_class = 0
|
||||||
|
|
||||||
|
if not max_output_boxes_per_class:
|
||||||
log.info('Set default "max_output_boxes_per_class" for node {} to number of boxes'.format(node.name))
|
log.info('Set default "max_output_boxes_per_class" for node {} to number of boxes'.format(node.name))
|
||||||
max_output_boxes_per_class = boxes_shape[1]
|
max_output_boxes_per_class = boxes_shape[1]
|
||||||
|
|
||||||
@ -73,15 +92,29 @@ class NonMaxSuppression(Op):
|
|||||||
num_input_boxes = boxes_shape[1]
|
num_input_boxes = boxes_shape[1]
|
||||||
assert scores_shape[2] == num_input_boxes, 'Number of boxes mismatch'
|
assert scores_shape[2] == num_input_boxes, 'Number of boxes mismatch'
|
||||||
|
|
||||||
if node.get_opset() == 'opset4':
|
if node.get_opset() in ['opset4', 'opset5']:
|
||||||
max_number_of_boxes = min(num_input_boxes, max_output_boxes_per_class) * boxes_shape[0] * num_classes
|
max_number_of_boxes = min(num_input_boxes, max_output_boxes_per_class) * boxes_shape[0] * num_classes
|
||||||
else:
|
else:
|
||||||
max_number_of_boxes = min(num_input_boxes, boxes_shape[0] * max_output_boxes_per_class * num_classes)
|
max_number_of_boxes = min(num_input_boxes, boxes_shape[0] * max_output_boxes_per_class * num_classes)
|
||||||
node.out_port(0).data.set_shape(int64_array([max_number_of_boxes, 3]))
|
node.out_port(0).data.set_shape(int64_array([max_number_of_boxes, 3]))
|
||||||
|
|
||||||
|
if opset == 'opset5':
|
||||||
|
num_of_outputs = len([port for port in node.out_ports().values() if not port.disconnected()])
|
||||||
|
if num_of_outputs >= 2 and node.has_port('out', 1):
|
||||||
|
node.out_port(1).data.set_shape(int64_array([max_number_of_boxes, 3]))
|
||||||
|
if num_of_outputs >= 3 and node.has_port('out', 2):
|
||||||
|
node.out_port(2).data.set_shape(int64_array(1))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def type_infer(node):
|
def type_infer(node):
|
||||||
if node.get_opset() in ['opset3', 'opset4']:
|
opset = node.get_opset()
|
||||||
|
if opset == 'opset5':
|
||||||
|
node.out_port(0).set_data_type(node.output_type)
|
||||||
|
if node.has_port('out', 1):
|
||||||
|
node.out_port(1).set_data_type(np.float32)
|
||||||
|
if node.has_port('out', 2):
|
||||||
|
node.out_port(2).set_data_type(np.int64)
|
||||||
|
elif opset in ['opset3', 'opset4']:
|
||||||
node.out_port(0).set_data_type(node.output_type)
|
node.out_port(0).set_data_type(node.output_type)
|
||||||
else:
|
else:
|
||||||
node.out_port(0).set_data_type(np.int64)
|
node.out_port(0).set_data_type(np.int64)
|
||||||
|
Loading…
Reference in New Issue
Block a user