set num_classes to None in multibox_detection extractor (#7860)

This commit is contained in:
Yegor Kruglov 2021-10-26 00:02:10 +03:00 committed by GitHub
parent 64a0e3dbd0
commit 345c3510f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -13,7 +13,10 @@ class MultiBoxDetectionOutputExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
attrs = get_mxnet_layer_attrs(node.symbol_dict)
num_classes = 21
# We can not get num_classes attribute from the operation, so it must be set to None.
# In this case num_classes attribute will be defined in the infer function in
# mo/front/common/partial_infer/multi_box_detection.py
num_classes = None
top_k = attrs.int("nms_topk", -1)
keep_top_k = top_k
variance_encoded_in_target = 0

View File

@ -21,7 +21,7 @@ class TestMultiBoxDetection_Parsing(unittest.TestCase):
exp_attrs = {
'type': 'DetectionOutput',
'num_classes': 21,
'num_classes': None,
'keep_top_k': 400,
'variance_encoded_in_target': 0,
'code_type': "caffe.PriorBoxParameter.CENTER_SIZE",
@ -51,7 +51,7 @@ class TestMultiBoxDetection_Parsing(unittest.TestCase):
exp_attrs = {
'type': 'DetectionOutput',
'num_classes': 21,
'num_classes': None,
'keep_top_k': -1,
'variance_encoded_in_target': 0,
'code_type': "caffe.PriorBoxParameter.CENTER_SIZE",