Fixed type check in convert_model(). (#14472)
* Fixed passing Path to convert_model(). * Update tools/mo/openvino/tools/mo/convert_impl.py Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com> Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com>
This commit is contained in:
parent
532000a0ce
commit
60bb9e7b7c
@ -9,6 +9,7 @@ import platform
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -831,7 +832,7 @@ def show_mo_convert_help():
|
||||
|
||||
|
||||
def input_model_is_object(argv):
|
||||
if isinstance(argv['input_model'], str):
|
||||
if isinstance(argv['input_model'], (str, Path)):
|
||||
return False
|
||||
if argv['input_model'] is None:
|
||||
return False
|
||||
|
@ -3,8 +3,8 @@
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
from openvino.runtime import serialize
|
||||
|
||||
@ -19,6 +19,66 @@ from utils import create_onnx_model, save_to_onnx
|
||||
class ConvertImportMOTest(UnitTestWithMockedTelemetry):
|
||||
test_directory = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
@staticmethod
|
||||
def create_onnx_model():
|
||||
#
|
||||
# Create ONNX model
|
||||
#
|
||||
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
shape = [1, 2, 3]
|
||||
|
||||
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, shape)
|
||||
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, shape)
|
||||
|
||||
node_def = onnx.helper.make_node(
|
||||
'Relu',
|
||||
inputs=['input'],
|
||||
outputs=['Relu_out'],
|
||||
)
|
||||
node_def2 = onnx.helper.make_node(
|
||||
'Sigmoid',
|
||||
inputs=['Relu_out'],
|
||||
outputs=['output'],
|
||||
)
|
||||
|
||||
# Create the graph (GraphProto)
|
||||
graph_def = helper.make_graph(
|
||||
[node_def, node_def2],
|
||||
'test_model',
|
||||
[input],
|
||||
[output],
|
||||
)
|
||||
|
||||
# Create the model (ModelProto)
|
||||
onnx_net = helper.make_model(graph_def, producer_name='test_model')
|
||||
return onnx_net
|
||||
|
||||
@staticmethod
|
||||
def create_model_ref():
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': [1, 2, 3], 'kind': 'data'},
|
||||
'relu': {'kind': 'op', 'type': 'ReLU'},
|
||||
'relu_data': {'shape': [1, 2, 3], 'kind': 'data'},
|
||||
'sigmoid': {'kind': 'op', 'type': 'Sigmoid'},
|
||||
'sigmoid_data': {'shape': [1, 2, 3], 'kind': 'data'},
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
ref_graph = build_graph(nodes_attributes,
|
||||
[('input', 'input_data'),
|
||||
('input_data', 'relu'),
|
||||
('relu', 'relu_data'),
|
||||
('relu_data', 'sigmoid'),
|
||||
('sigmoid', 'sigmoid_data'),
|
||||
('sigmoid_data', 'result'),
|
||||
])
|
||||
return ref_graph
|
||||
|
||||
@generate(*[
|
||||
({}),
|
||||
({'input': InputCutInfo(name='LeakyRelu_out', shape=None, type=None, value=None)}),
|
||||
@ -37,66 +97,26 @@ class ConvertImportMOTest(UnitTestWithMockedTelemetry):
|
||||
serialize(ov_model, out_xml.encode('utf-8'), out_xml.replace('.xml', '.bin').encode('utf-8'))
|
||||
assert os.path.exists(out_xml)
|
||||
|
||||
def test_input_model_path(self):
|
||||
from openvino.tools.mo import convert_model
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=self.test_directory) as tmpdir:
|
||||
model = self.create_onnx_model()
|
||||
model_path = save_to_onnx(model, tmpdir)
|
||||
out_xml = os.path.join(tmpdir, Path("model.xml"))
|
||||
|
||||
ov_model = convert_model(input_model=model_path)
|
||||
serialize(ov_model, out_xml.encode('utf-8'), out_xml.replace('.xml', '.bin').encode('utf-8'))
|
||||
|
||||
ir = IREngine(out_xml, out_xml.replace('.xml', '.bin'))
|
||||
ref_graph = self.create_model_ref()
|
||||
flag, resp = ir.compare(ref_graph)
|
||||
assert flag, '\n'.join(resp)
|
||||
|
||||
def test_unnamed_input_model(self):
|
||||
def create_onnx_model():
|
||||
#
|
||||
# Create ONNX model
|
||||
#
|
||||
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
shape = [1, 2, 3]
|
||||
|
||||
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, shape)
|
||||
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, shape)
|
||||
|
||||
node_def = onnx.helper.make_node(
|
||||
'Relu',
|
||||
inputs=['input'],
|
||||
outputs=['Relu_out'],
|
||||
)
|
||||
node_def2 = onnx.helper.make_node(
|
||||
'Sigmoid',
|
||||
inputs=['Relu_out'],
|
||||
outputs=['output'],
|
||||
)
|
||||
|
||||
# Create the graph (GraphProto)
|
||||
graph_def = helper.make_graph(
|
||||
[node_def, node_def2],
|
||||
'test_model',
|
||||
[input],
|
||||
[output],
|
||||
)
|
||||
|
||||
# Create the model (ModelProto)
|
||||
onnx_net = helper.make_model(graph_def, producer_name='test_model')
|
||||
return onnx_net
|
||||
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': [1, 2, 3], 'kind': 'data'},
|
||||
'relu': {'kind': 'op', 'type': 'ReLU'},
|
||||
'relu_data': {'shape': [1, 2, 3], 'kind': 'data'},
|
||||
'sigmoid': {'kind': 'op', 'type': 'Sigmoid'},
|
||||
'sigmoid_data': {'shape': [1, 2, 3], 'kind': 'data'},
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
ref_graph = build_graph(nodes_attributes,
|
||||
[('input', 'input_data'),
|
||||
('input_data', 'relu'),
|
||||
('relu', 'relu_data'),
|
||||
('relu_data', 'sigmoid'),
|
||||
('sigmoid', 'sigmoid_data'),
|
||||
('sigmoid_data', 'result'),
|
||||
])
|
||||
|
||||
from openvino.tools.mo import convert_model
|
||||
with tempfile.TemporaryDirectory(dir=self.test_directory) as tmpdir:
|
||||
model = create_onnx_model()
|
||||
model = self.create_onnx_model()
|
||||
model_path = save_to_onnx(model, tmpdir)
|
||||
out_xml = os.path.join(tmpdir, "model.xml")
|
||||
|
||||
@ -104,5 +124,6 @@ class ConvertImportMOTest(UnitTestWithMockedTelemetry):
|
||||
serialize(ov_model, out_xml.encode('utf-8'), out_xml.replace('.xml', '.bin').encode('utf-8'))
|
||||
|
||||
ir = IREngine(out_xml, out_xml.replace('.xml', '.bin'))
|
||||
ref_graph = self.create_model_ref()
|
||||
flag, resp = ir.compare(ref_graph)
|
||||
assert flag, '\n'.join(resp)
|
||||
|
Loading…
Reference in New Issue
Block a user