add getting nms_threshold / iou_threshold from RetinaNet (#3075)

* added getting nms_threshold/iou_threshold from original TF RetinaNet model

* iou_threshold definition added

* fixed getting iou_threshold for TF NMS V2, some minor corrections

* added box_encoding to NMS extractors
This commit is contained in:
Pavel Esir 2020-11-12 15:04:07 +03:00 committed by GitHub
parent 0f4525affc
commit 8c89d8d733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 10 deletions

View File

@ -79,11 +79,10 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
end.out_port(0).connect(shape_part_for_tiling.in_port(2))
stride.out_port(0).connect(shape_part_for_tiling.in_port(3))
concat_value = Const(graph, {'value': np.array([4])}).create_node()
shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
'axis': np.array(0)}).create_node()
shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
concat_value.out_port(0).connect(shape_concat.in_port(1))
shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
{'name': name + '/shape_for_tiling', 'in_ports_count': 2,
'axis': int64_array(0)},
shape_part_for_tiling)
variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
@ -246,9 +245,19 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
applied_width_height_regressions_node)
detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
# get nms from the original network
iou_threshold = None
nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
if len(nms_nodes) > 0:
# it is highly unlikely that for different classes NMS has different
# moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
iou_threshold = nms_nodes[0].in_node(3).value
if iou_threshold is None:
raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))
detection_output_node = detection_output_op.create_node(
[reshape_regression_node, reshape_classes_node, priors],
dict(name=detection_output_op.attrs['type'], clip_after_nms=1, normalized=1, variance_encoded_in_target=0,
background_label_id=1000))
dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
variance_encoded_in_target=0, background_label_id=1000))
return {'detection_output_node': detection_output_node}

View File

@ -21,13 +21,24 @@ from extensions.ops.non_max_suppression import NonMaxSuppression
from mo.front.extractor import FrontExtractorOp
class NonMaxSuppressionV2Extractor(FrontExtractorOp):
op = 'NonMaxSuppressionV2'
enabled = True
@classmethod
def extract(cls, node):
attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32}
NonMaxSuppression.update_node_stat(node, attrs)
return cls.enabled
class NonMaxSuppressionV3Extractor(FrontExtractorOp):
op = 'NonMaxSuppressionV3'
enabled = True
@classmethod
def extract(cls, node):
attrs = {'sort_result_descending': 1, 'center_point_box': 0, 'output_type': np.int32}
attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32}
NonMaxSuppression.update_node_stat(node, attrs)
return cls.enabled

View File

@ -7,7 +7,6 @@
"confidence_threshold": 0.05,
"top_k": 6000,
"keep_top_k": 300,
"nms_threshold": 0.5,
"variance": [0.2, 0.2, 0.2, 0.2]
},
"include_inputs_to_sub_graph": true,

View File

@ -34,7 +34,6 @@ class NonMaxSuppression(Op):
'version': 'opset5',
'infer': self.infer,
'output_type': np.int64,
'center_point_box': 0,
'box_encoding': 'corner',
'in_ports_count': 5,
'sort_result_descending': 1,