Files
openvino/tests/layer_tests/common/onnx_layer_test_class.py
2022-01-19 01:07:49 +03:00

64 lines
1.9 KiB
Python

# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
from common.layer_test_class import CommonLayerTest
from common.layer_utils import BaseInfer
def save_to_onnx(onnx_model, path_to_saved_onnx_model):
    """Serialize an ONNX model to ``model.onnx`` inside the given directory.

    :param onnx_model: model object handed to ``onnx.save`` (presumably an
        ``onnx.ModelProto``).
    :param path_to_saved_onnx_model: directory in which ``model.onnx`` is created.
    :return: full path of the written ``model.onnx`` file.
    """
    import onnx

    model_file = os.path.join(path_to_saved_onnx_model, 'model.onnx')
    onnx.save(onnx_model, model_file)
    # Fail the calling test right away if nothing ended up on disk.
    written = os.path.isfile(model_file)
    assert written, "model.onnx haven't been saved here: {}".format(path_to_saved_onnx_model)
    return model_file
class Caffe2OnnxLayerTest(CommonLayerTest):
    """Layer test whose framework reference results come from Caffe2's ONNX backend."""

    def produce_model_path(self, framework_model, save_path):
        """Save ``framework_model`` as ``model.onnx`` under ``save_path``; return its path."""
        return save_to_onnx(framework_model, save_path)

    def get_framework_results(self, inputs_dict, model_path):
        """Evaluate the ONNX model at ``model_path`` with Caffe2.

        :param inputs_dict: model inputs, passed unchanged to ``run_model``.
        :param model_path: path to a serialized ``.onnx`` file.
        :return: dict mapping each output field name of the Caffe2 result to its value.
        """
        # Load the ONNX model
        import onnx
        model = onnx.load(model_path)

        # Run the ONNX model with Caffe2
        import caffe2.python.onnx.backend
        caffe2_res = caffe2.python.onnx.backend.run_model(model, inputs_dict)

        # The backend result lists its output names in ``_fields`` and supports
        # lookup by name; flatten it into a plain dict for the comparison code.
        return {field: caffe2_res[field] for field in caffe2_res._fields}
class OnnxRuntimeInfer(BaseInfer):
    """Framework-side inference runner backed by ONNX Runtime."""

    def __init__(self, net):
        """
        :param net: model handed to ``onnxruntime.InferenceSession``
            (``OnnxRuntimeLayerTest`` passes the path of a saved ``.onnx`` file).
        """
        super().__init__('OnnxRuntime')
        self.net = net

    def fw_infer(self, input_data):
        """Run a single inference with ONNX Runtime.

        :param input_data: feed passed unchanged to ``InferenceSession.run``.
        :return: dict mapping each model output name to its computed value.
        """
        import onnxruntime as rt

        sess = rt.InferenceSession(self.net)
        # output_names=None asks run() for every model output; the results are
        # paired positionally with the output metadata from get_outputs().
        out = sess.run(None, input_data)
        # The session is released when this local goes out of scope on return,
        # so no explicit ``del`` is needed.
        return {output.name: out[i] for i, output in enumerate(sess.get_outputs())}
class OnnxRuntimeLayerTest(CommonLayerTest):
    """Layer test whose framework reference results come from ONNX Runtime."""

    def produce_model_path(self, framework_model, save_path):
        """Write ``framework_model`` to ``<save_path>/model.onnx`` and return that path."""
        return save_to_onnx(framework_model, save_path)

    def get_framework_results(self, inputs_dict, model_path):
        """Infer the model stored at ``model_path`` on ``inputs_dict`` via ONNX Runtime."""
        return OnnxRuntimeInfer(net=model_path).infer(input_data=inputs_dict)