Enable yolo v3 test, added post-processing step (#3510)
This commit is contained in:
parent
c80e3c3a82
commit
261cb6ecf8
@ -221,7 +221,6 @@ xfail_issue_39662 = xfail_test(reason="RuntimeError: 'ScatterElementsUpdate' lay
|
||||
"indices value that points to non-existing output tensor element")
|
||||
xfail_issue_39663 = xfail_test(reason="RuntimeError: Unsupported primitive of type: ROIAlign name: Y")
|
||||
xfail_issue_43380 = xfail_test(reason="RuntimeError: Sorting not possible, due to existed loop")
|
||||
xfail_issue_43382 = xfail_test(reason="Testing models which have upper bound output shape is not supported")
|
||||
xfail_issue_41894 = xfail_test(reason="CPU plugin elementwise computation missmatch")
|
||||
|
||||
|
||||
|
@ -19,6 +19,8 @@ import tests
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
import os
|
||||
from typing import Sequence, Any
|
||||
import numpy as np
|
||||
|
||||
from tests.test_onnx.utils import OpenVinoOnnxBackend
|
||||
from tests.test_onnx.utils.model_importer import ModelImportRunner
|
||||
@ -27,7 +29,6 @@ from tests import (
|
||||
xfail_issue_38701,
|
||||
xfail_issue_43742,
|
||||
xfail_issue_43380,
|
||||
xfail_issue_43382,
|
||||
xfail_issue_43439,
|
||||
xfail_issue_39684,
|
||||
xfail_issue_40957,
|
||||
@ -46,6 +47,18 @@ from tests import (
|
||||
|
||||
MODELS_ROOT_DIR = tests.MODEL_ZOO_DIR
|
||||
|
||||
def yolov3_post_processing(outputs : Sequence[Any]) -> Sequence[Any]:
|
||||
concat_out_index = 2
|
||||
# remove all elements with value -1 from yolonms_layer_1/concat_2:0 output
|
||||
concat_out = outputs[concat_out_index][outputs[concat_out_index] != -1]
|
||||
concat_out = np.expand_dims(concat_out, axis=0)
|
||||
outputs[concat_out_index] = concat_out
|
||||
return outputs
|
||||
|
||||
post_processing = {
|
||||
"yolov3" : {"post_processing" : yolov3_post_processing}
|
||||
}
|
||||
|
||||
tolerance_map = {
|
||||
"arcface_lresnet100e_opset8": {"atol": 0.001, "rtol": 0.001},
|
||||
"fp16_inception_v1": {"atol": 0.001, "rtol": 0.001},
|
||||
@ -117,6 +130,8 @@ for path in Path(MODELS_ROOT_DIR).rglob("*.onnx"):
|
||||
# updated model looks now:
|
||||
# {"model_name": path, "model_file": file, "dir": mdir, "atol": ..., "rtol": ...}
|
||||
model.update(tolerance_map[basedir])
|
||||
if basedir in post_processing:
|
||||
model.update(post_processing[basedir])
|
||||
zoo_models.append(model)
|
||||
|
||||
if len(zoo_models) > 0:
|
||||
@ -163,7 +178,6 @@ if len(zoo_models) > 0:
|
||||
(xfail_issue_39669, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_encoder_12_t5_encoder_cpu"),
|
||||
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_mask_rcnn_model_MaskRCNN_10_mask_rcnn_R_50_FPN_1x_cpu"),
|
||||
(xfail_issue_38084, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"),
|
||||
(xfail_issue_43382, "test_onnx_model_zoo_vision_object_detection_segmentation_yolov3_model_yolov3_10_yolov3_yolov3_cpu"),
|
||||
(xfail_issue_43380, "test_onnx_model_zoo_vision_object_detection_segmentation_tiny_yolov3_model_tiny_yolov3_11_yolov3_tiny_cpu"),
|
||||
|
||||
# Model MSFT
|
||||
@ -182,8 +196,7 @@ if len(zoo_models) > 0:
|
||||
(xfail_issue_39669, "test_MSFT_opset9_cgan_cgan_cpu"),
|
||||
(xfail_issue_40957, "test_MSFT_opset10_BERT_Squad_bertsquad10_cpu"),
|
||||
|
||||
(xfail_issue_43380, "test_MSFT_opset11_tinyyolov3_yolov3_tiny_cpu"),
|
||||
(xfail_issue_43382, "test_MSFT_opset10_yolov3_yolov3_cpu"),
|
||||
(xfail_issue_43380, "test_MSFT_opset11_tinyyolov3_yolov3_tiny_cpu")
|
||||
|
||||
]
|
||||
for test_case in import_xfail_list + execution_xfail_list:
|
||||
|
@ -19,14 +19,18 @@ import onnx
|
||||
import onnx.backend.test
|
||||
import unittest
|
||||
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, namedtuple
|
||||
from onnx import numpy_helper, NodeProto, ModelProto
|
||||
from onnx.backend.base import Backend, BackendRep
|
||||
from onnx.backend.test.case.test_case import TestCase as OnnxTestCase
|
||||
from onnx.backend.test.runner import TestItem
|
||||
from pathlib import Path
|
||||
from tests.test_onnx.utils.onnx_helpers import import_onnx_model
|
||||
from typing import Any, Dict, List, Optional, Pattern, Set, Text, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Pattern, Set, Text, Type, Union, Callable, Sequence
|
||||
|
||||
|
||||
# add post-processing function as part of test data
|
||||
ExtOnnxTestCase = namedtuple("TestCaseExt", OnnxTestCase._fields + ("post_processing",))
|
||||
|
||||
|
||||
class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
@ -51,7 +55,7 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
.replace("\\", "_") \
|
||||
.replace("-", "_")
|
||||
|
||||
test_case = OnnxTestCase(
|
||||
test_case = ExtOnnxTestCase(
|
||||
name=test_name,
|
||||
url=None,
|
||||
model_name=model["model_name"],
|
||||
@ -61,6 +65,7 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
kind="OnnxBackendRealModelTest",
|
||||
rtol=model.get("rtol", 0.001),
|
||||
atol=model.get("atol", 1e-07),
|
||||
post_processing=model.get("post_processing", None)
|
||||
)
|
||||
self._add_model_import_test(test_case)
|
||||
self._add_model_execution_test(test_case)
|
||||
@ -72,7 +77,7 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
|
||||
return onnx.load(model_dir / filename)
|
||||
|
||||
def _add_model_import_test(self, model_test: OnnxTestCase) -> None:
|
||||
def _add_model_import_test(self, model_test: ExtOnnxTestCase) -> None:
|
||||
# model is loaded at runtime, note sometimes it could even
|
||||
# never loaded if the test skipped
|
||||
model_marker = [None] # type: List[Optional[Union[ModelProto, NodeProto]]]
|
||||
@ -87,6 +92,7 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
@classmethod
|
||||
def _execute_npz_data(
|
||||
cls, model_dir: str, prepared_model: BackendRep, result_rtol: float, result_atol: float,
|
||||
post_processing: Callable[[Sequence[Any]], Sequence[Any]] = None
|
||||
) -> int:
|
||||
executed_tests = 0
|
||||
for test_data_npz in model_dir.glob("test_data_*.npz"):
|
||||
@ -94,6 +100,8 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
inputs = list(test_data["inputs"])
|
||||
outputs = list(prepared_model.run(inputs))
|
||||
ref_outputs = test_data["outputs"]
|
||||
if post_processing is not None:
|
||||
outputs = post_processing(outputs)
|
||||
cls.assert_similar_outputs(ref_outputs, outputs, result_rtol, result_atol)
|
||||
executed_tests = executed_tests + 1
|
||||
return executed_tests
|
||||
@ -101,6 +109,7 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
@classmethod
|
||||
def _execute_pb_data(
|
||||
cls, model_dir: str, prepared_model: BackendRep, result_rtol: float, result_atol: float,
|
||||
post_processing: Callable[[Sequence[Any]], Sequence[Any]] = None
|
||||
) -> int:
|
||||
executed_tests = 0
|
||||
for test_data_dir in model_dir.glob("test_data_set*"):
|
||||
@ -123,11 +132,13 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
if(len(inputs) == 0):
|
||||
continue
|
||||
outputs = list(prepared_model.run(inputs))
|
||||
if post_processing is not None:
|
||||
outputs = post_processing(outputs)
|
||||
cls.assert_similar_outputs(ref_outputs, outputs, result_rtol, result_atol)
|
||||
executed_tests = executed_tests + 1
|
||||
return executed_tests
|
||||
|
||||
def _add_model_execution_test(self, model_test: OnnxTestCase) -> None:
|
||||
def _add_model_execution_test(self, model_test: ExtOnnxTestCase) -> None:
|
||||
# model is loaded at runtime, note sometimes it could even
|
||||
# never loaded if the test skipped
|
||||
model_marker = [None] # type: List[Optional[Union[ModelProto, NodeProto]]]
|
||||
@ -138,12 +149,13 @@ class ModelImportRunner(onnx.backend.test.BackendTest):
|
||||
prepared_model = self.backend.prepare(model, device)
|
||||
assert prepared_model is not None
|
||||
executed_tests = ModelImportRunner._execute_npz_data(
|
||||
model_test.model_dir, prepared_model, model_test.rtol, model_test.atol
|
||||
model_test.model_dir, prepared_model, model_test.rtol, model_test.atol,
|
||||
model_test.post_processing
|
||||
)
|
||||
|
||||
executed_tests = executed_tests + ModelImportRunner._execute_pb_data(
|
||||
model_test.model_dir, prepared_model, model_test.rtol, model_test.atol
|
||||
model_test.model_dir, prepared_model, model_test.rtol, model_test.atol,
|
||||
model_test.post_processing
|
||||
)
|
||||
|
||||
assert executed_tests > 0, "This model has no test data"
|
||||
self._add_test("ModelExecution", model_test.name, run_execution, model_marker)
|
||||
|
Loading…
Reference in New Issue
Block a user