convert_model() slow conversion time fix (#14666)
* Removed input_model from meta data dictionary. * Added test. * Changed check for more general case. * Test fixed. * Temporarily added debug print. * Fixed test. * Code corrections. * Small correction. * Added type check. * Added comments, small corrections. * Refactored MO convert_model() to have single parse_args(). * Small correction. * Fixed PyTorch converting. * Small correction. * Code refactoring, added tests.
This commit is contained in:
parent
e2635a0053
commit
993686b266
@ -427,6 +427,8 @@ def add_net_rt_info(net: Element, meta_info: dict):
|
||||
else:
|
||||
meta = SubElement(net, 'rt_info')
|
||||
for key, value in meta_info.items():
|
||||
if isinstance(value, dict) and value == {}:
|
||||
continue
|
||||
add_meta_data_elem(meta, key, value)
|
||||
|
||||
|
||||
|
@ -2,7 +2,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from openvino.frontend import FrontEndManager
|
||||
|
||||
from openvino.tools.mo.convert_impl import _convert
|
||||
from openvino.tools.mo.utils.cli_parser import get_all_cli_parser
|
||||
|
||||
InputCutInfo = namedtuple("InputInfo", ["name", "shape", "type", "value"])
|
||||
LayoutMap = namedtuple("LayoutMap", ["source_layout", "target_layout"])
|
||||
@ -44,4 +48,9 @@ def convert_model(input_model=None, **args):
|
||||
openvino.runtime.Model
|
||||
"""
|
||||
args.update({'input_model': input_model})
|
||||
return _convert(**args)
|
||||
|
||||
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||
framework = None if 'framework' not in args else args['framework']
|
||||
|
||||
ov_model, _ = _convert(cli_parser, framework, args)
|
||||
return ov_model
|
||||
|
@ -754,28 +754,18 @@ def args_dict_to_list(cli_parser, **kwargs):
|
||||
return result
|
||||
|
||||
|
||||
def pack_params_to_args_namespace(**kwargs):
|
||||
fe_manager = FrontEndManager()
|
||||
cli_parser = get_all_cli_parser(fe_manager)
|
||||
argv = cli_parser.parse_args(args_dict_to_list(cli_parser, **kwargs))
|
||||
|
||||
all_params = {}
|
||||
for key, value in mo_convert_params.items():
|
||||
all_params.update(value)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if key not in argv and key not in all_params.keys():
|
||||
raise Error("Unrecognized argument: {}".format(key))
|
||||
if value is not None:
|
||||
setattr(argv, key, value)
|
||||
send_params_info(argv, cli_parser)
|
||||
|
||||
def get_non_default_params(argv, cli_parser):
|
||||
import numbers
|
||||
# make dictionary with parameters which have non-default values to be serialized in IR in rt_info
|
||||
non_default_params = {}
|
||||
for arg in vars(argv):
|
||||
arg_value = getattr(argv, arg)
|
||||
for arg, arg_value in vars(argv).items():
|
||||
if arg_value != cli_parser.get_default(arg):
|
||||
non_default_params[arg] = depersonalize(arg_value, arg)
|
||||
return argv, non_default_params
|
||||
value = depersonalize(arg_value, arg)
|
||||
# Skip complex classes in params to prevent
|
||||
# serializing it to rt_info
|
||||
if isinstance(value, (str, bool, numbers.Number)):
|
||||
non_default_params[arg] = value
|
||||
return non_default_params
|
||||
|
||||
|
||||
def params_to_string(**kwargs):
|
||||
@ -833,6 +823,10 @@ def show_mo_convert_help():
|
||||
|
||||
|
||||
def input_model_is_object(argv):
|
||||
# Input model can be set as object only for --input_model parameter.
|
||||
# --saved_model_dir or meta specific options are only used to store paths to the input model.
|
||||
if 'input_model' not in argv:
|
||||
return False
|
||||
if isinstance(argv['input_model'], (str, Path)):
|
||||
return False
|
||||
if argv['input_model'] is None:
|
||||
@ -840,6 +834,30 @@ def input_model_is_object(argv):
|
||||
return True
|
||||
|
||||
|
||||
def pack_params_to_args_namespace(args: dict, cli_parser: argparse.ArgumentParser):
|
||||
if len(args) > 0:
|
||||
args_string = params_to_string(**args)
|
||||
argv, _ = cli_parser.parse_known_args(args_dict_to_list(cli_parser, **args_string))
|
||||
|
||||
# get list of all available params for convert_model()
|
||||
all_params = {}
|
||||
for key, value in mo_convert_params.items():
|
||||
all_params.update(value)
|
||||
|
||||
# check that there are no unknown params provided
|
||||
for key, value in args_string.items():
|
||||
if key not in argv and key not in all_params.keys():
|
||||
raise Error("Unrecognized argument: {}".format(key))
|
||||
|
||||
# Non string params like input_model or extensions are ignored by parse_args()
|
||||
# so we need to set them in argv separately
|
||||
if value is not None and getattr(argv, key) != value:
|
||||
setattr(argv, key, value)
|
||||
else:
|
||||
argv = cli_parser.parse_args()
|
||||
return argv
|
||||
|
||||
|
||||
def remove_tmp_onnx_model(out_dir):
|
||||
if not os.environ.get('SAVE_TO_BYTES_IO_ONNX_MODEL'):
|
||||
tmp_onnx_model = get_onnx_temp_filename(out_dir)
|
||||
@ -848,7 +866,7 @@ def remove_tmp_onnx_model(out_dir):
|
||||
os.remove(tmp_onnx_model)
|
||||
|
||||
|
||||
def _convert(**args):
|
||||
def _convert(cli_parser: argparse.ArgumentParser, framework, args):
|
||||
if 'help' in args and args['help']:
|
||||
show_mo_convert_help()
|
||||
return None
|
||||
@ -886,19 +904,27 @@ def _convert(**args):
|
||||
args['onnx_opset_version'] = None
|
||||
|
||||
try:
|
||||
ov_model = _convert(**args)
|
||||
ov_model = _convert(cli_parser, framework, args)
|
||||
except Exception as e:
|
||||
remove_tmp_onnx_model(out_dir)
|
||||
raise e
|
||||
|
||||
remove_tmp_onnx_model(out_dir)
|
||||
return ov_model
|
||||
args = params_to_string(**args)
|
||||
argv, non_default_params = pack_params_to_args_namespace(**args)
|
||||
|
||||
argv = pack_params_to_args_namespace(args, cli_parser)
|
||||
|
||||
if framework is not None:
|
||||
setattr(argv, 'framework', framework)
|
||||
|
||||
# send telemetry with params info
|
||||
send_params_info(argv, cli_parser)
|
||||
|
||||
non_default_params = get_non_default_params(argv, cli_parser)
|
||||
|
||||
if inp_model_is_object:
|
||||
argv.model_name = "model"
|
||||
if argv.model_name is None:
|
||||
if not hasattr(argv, "model_name") or argv.model_name is None:
|
||||
argv.model_name = get_model_name_from_args(argv)
|
||||
|
||||
if model_framework is not None:
|
||||
@ -929,7 +955,7 @@ def _convert(**args):
|
||||
telemetry.send_event('mo', 'conversion_result', 'success')
|
||||
telemetry.end_session('mo')
|
||||
telemetry.force_shutdown(1.0)
|
||||
return ov_model
|
||||
return ov_model, argv
|
||||
except Exception as e:
|
||||
telemetry.send_event('mo', 'conversion_result', 'fail')
|
||||
telemetry.end_session('mo')
|
||||
|
@ -11,7 +11,7 @@ try:
|
||||
except ImportError:
|
||||
import openvino.tools.mo.utils.telemetry_stub as tm
|
||||
|
||||
from openvino.tools.mo.convert import convert_model
|
||||
from openvino.tools.mo.convert_impl import _convert
|
||||
from openvino.tools.mo.pipeline.common import get_ir_version
|
||||
from openvino.tools.mo.utils.cli_parser import get_model_name_from_args
|
||||
from openvino.tools.mo.utils.logger import init_logger
|
||||
@ -32,29 +32,20 @@ def main(cli_parser: argparse.ArgumentParser, framework=None):
|
||||
# Initialize logger with 'ERROR' as default level to be able to form nice messages
|
||||
# before arg parser deliver log_level requested by user
|
||||
init_logger('ERROR', False)
|
||||
logger = log.getLogger()
|
||||
# Disable logging for parse_args() as inner convert runs parse_args() second time
|
||||
# which result in duplicating of warnings
|
||||
logger.disabled = True
|
||||
argv = cli_parser.parse_args()
|
||||
logger.disabled = False
|
||||
argv.model_name = get_model_name_from_args(argv)
|
||||
is_tf, _, _, _, _ = deduce_legacy_frontend_by_namespace(argv)
|
||||
argv = vars(argv)
|
||||
|
||||
if framework is not None:
|
||||
argv['framework'] = framework
|
||||
|
||||
ngraph_function = None
|
||||
argv = None
|
||||
is_tf = False
|
||||
try:
|
||||
ngraph_function = convert_model(**argv)
|
||||
ngraph_function, argv = _convert(cli_parser, framework, {})
|
||||
is_tf, _, _, _, _ = deduce_legacy_frontend_by_namespace(argv)
|
||||
ov_update_message = get_ov_update_message()
|
||||
ov_api20_message = get_ov_api20_message()
|
||||
if ov_update_message is not None:
|
||||
print(ov_update_message)
|
||||
if ov_api20_message is not None and ngraph_function is not None:
|
||||
print(ov_api20_message)
|
||||
if argv['use_new_frontend'] and is_tf:
|
||||
if argv.use_new_frontend and is_tf:
|
||||
print(get_tf_fe_message())
|
||||
|
||||
except (FileNotFoundError, NotADirectoryError) as e:
|
||||
@ -66,7 +57,7 @@ def main(cli_parser: argparse.ArgumentParser, framework=None):
|
||||
for el in analysis_results.get_messages():
|
||||
log.error(el, extra={'analysis_info': True})
|
||||
log.error(err)
|
||||
if not argv['use_new_frontend'] and is_tf:
|
||||
if hasattr(argv, 'use_new_frontend') and not argv.use_new_frontend and is_tf:
|
||||
print(get_tf_fe_legacy_message())
|
||||
log.debug(traceback.format_exc())
|
||||
except FrameworkError as err:
|
||||
@ -81,21 +72,21 @@ def main(cli_parser: argparse.ArgumentParser, framework=None):
|
||||
log.error(traceback.format_exc())
|
||||
log.error("---------------- END OF BUG REPORT --------------")
|
||||
log.error("-------------------------------------------------")
|
||||
if not argv['use_new_frontend'] and is_tf:
|
||||
if hasattr(argv, 'use_new_frontend') and not argv.use_new_frontend and is_tf:
|
||||
print(get_tf_fe_legacy_message())
|
||||
|
||||
if ngraph_function is None:
|
||||
return 1
|
||||
|
||||
output_dir = argv['output_dir'] if argv['output_dir'] != '.' else os.getcwd()
|
||||
model_path_no_ext = os.path.normpath(os.path.join(output_dir, argv['model_name']))
|
||||
output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
|
||||
model_path_no_ext = os.path.normpath(os.path.join(output_dir, argv.model_name))
|
||||
model_path = model_path_no_ext + '.xml'
|
||||
|
||||
serialize(ngraph_function, model_path.encode('utf-8'), model_path.replace('.xml', '.bin').encode('utf-8'))
|
||||
|
||||
# generate .mapping file
|
||||
path_to_mapping = model_path_no_ext + ".mapping"
|
||||
extract_names = argv['framework'] in ['tf', 'mxnet', 'kaldi']
|
||||
extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
|
||||
generate_mapping_file(ngraph_function, path_to_mapping, extract_names)
|
||||
|
||||
print('[ SUCCESS ] Generated IR version {} model.'.format(get_ir_version(argv)))
|
||||
|
90
tools/mo/unit_tests/mo/convert/meta_data_test_actual.py
Normal file
90
tools/mo/unit_tests/mo/convert/meta_data_test_actual.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from openvino.runtime import get_version as get_rt_version
|
||||
from openvino.runtime import serialize
|
||||
|
||||
from openvino.tools.mo import convert_model
|
||||
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
||||
from openvino.tools.mo.utils.version import get_version
|
||||
|
||||
|
||||
class MetaDataTestTF(unittest.TestCase):
|
||||
test_directory = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
@staticmethod
|
||||
def check_meta_data(ov_model, ref_meta):
|
||||
ignore_attrs = ['version', 'optimization']
|
||||
for key, value in ref_meta.items():
|
||||
if key == 'conversion_parameters':
|
||||
for param_name, param_value in value.items():
|
||||
val = ov_model.get_rt_info([key, param_name])
|
||||
if param_name in ['extensions', 'caffe_parser_path', 'input_model', 'k', 'output_dir']:
|
||||
val = Path(val)
|
||||
assert val == param_value, \
|
||||
"Runtime info attribute with name {} does not match. Expected: {}, " \
|
||||
"got {}".format(param_name, param_value, val)
|
||||
continue
|
||||
assert str(ov_model.get_rt_info(key)) == value, \
|
||||
"Runtime info attribute with name {} does not match. Expected: {}, " \
|
||||
"got {}".format(key, value, ov_model.get_rt_info(key))
|
||||
|
||||
for key, value in ov_model.get_rt_info().items():
|
||||
if key in ignore_attrs:
|
||||
continue
|
||||
assert key in ref_meta, "Unexpected runtime info attribute: {}".format(key)
|
||||
|
||||
def test_meta_data_tf(self):
|
||||
def create_tf_model():
|
||||
import tensorflow as tf
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
with tf.compat.v1.Session() as sess:
|
||||
inp1 = tf.compat.v1.placeholder(tf.float32, [1, 2, 3], 'Input')
|
||||
inp2 = tf.compat.v1.placeholder(tf.float32, [1, 2, 3], 'Input')
|
||||
relu = tf.nn.relu(inp1 + inp2, name='Relu')
|
||||
|
||||
output = tf.nn.sigmoid(relu, name='Sigmoid')
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
return tf_net
|
||||
|
||||
def ref_meta_data():
|
||||
return {
|
||||
'MO_version': get_version(),
|
||||
'Runtime_version': get_rt_version(),
|
||||
'legacy_frontend': "True",
|
||||
'conversion_parameters': {
|
||||
'scale': "1.5",
|
||||
'batch': "1"
|
||||
}
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=self.test_directory) as tmpdir:
|
||||
|
||||
model = create_tf_model()
|
||||
out_xml = os.path.join(tmpdir, "model.xml")
|
||||
ref_meta = ref_meta_data()
|
||||
|
||||
ov_model = convert_model(model, scale=1.5, batch=1)
|
||||
self.check_meta_data(ov_model, ref_meta)
|
||||
|
||||
serialize(ov_model, out_xml.encode('utf-8'), out_xml.replace('.xml', '.bin').encode('utf-8'))
|
||||
|
||||
from openvino.runtime import Core
|
||||
core = Core()
|
||||
deserialized_model = core.read_model(out_xml)
|
||||
self.check_meta_data(deserialized_model, ref_meta)
|
||||
|
||||
restored_graph, meta_data = restore_graph_from_ir(out_xml, out_xml.replace('.xml', '.bin'))
|
||||
save_restored_graph(restored_graph, tmpdir, meta_data, "mo_ir_reader_test_model")
|
||||
|
||||
mo_ir_reader_test_model = core.read_model(os.path.join(tmpdir, "mo_ir_reader_test_model.xml"))
|
||||
self.check_meta_data(mo_ir_reader_test_model, ref_meta)
|
@ -76,6 +76,15 @@ def test_main_error_log():
|
||||
assert test_log == ref_log
|
||||
|
||||
|
||||
def test_rt_info():
|
||||
setup_env()
|
||||
args = [sys.executable, '-m', 'pytest',
|
||||
os.path.join(os.path.dirname(__file__), 'convert/meta_data_test_actual.py'), '-s']
|
||||
|
||||
status = subprocess.run(args, env=os.environ, capture_output=True)
|
||||
assert not status.returncode
|
||||
|
||||
|
||||
def test_mo_extensions_test():
|
||||
setup_env()
|
||||
args = [sys.executable, '-m', 'pytest',
|
||||
|
@ -2,7 +2,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import numpy
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
@ -14,10 +16,13 @@ import numpy as np
|
||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
|
||||
get_model_name, \
|
||||
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
||||
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values, get_data_type_from_input_value
|
||||
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values, get_data_type_from_input_value, get_all_cli_parser
|
||||
from openvino.tools.mo.convert_impl import pack_params_to_args_namespace
|
||||
from openvino.tools.mo.convert import InputCutInfo, LayoutMap
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
from unit_tests.mo.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry
|
||||
from openvino.runtime import PartialShape, Dimension
|
||||
from openvino.runtime import PartialShape, Dimension, Layout
|
||||
from openvino.frontend import FrontEndManager
|
||||
|
||||
|
||||
class TestingMeanScaleGetter(UnitTestWithMockedTelemetry):
|
||||
@ -1955,3 +1960,56 @@ class TestLayoutParsingEmptyNamesNoBrackets(unittest.TestCase):
|
||||
def wrong_case_3(self):
|
||||
argv_source_layout = "nchv->"
|
||||
self.assertRaises(get_layout_values(argv_source_layout=argv_source_layout))
|
||||
|
||||
class TestPackParamsToArgsNamespace(unittest.TestCase):
|
||||
def test_mo_convert_params(self):
|
||||
from openvino.frontend import ConversionExtension
|
||||
args = {'input_model': os.path.dirname(__file__),
|
||||
'input_shape': [PartialShape([1,100,100,3]), [2,3]],
|
||||
'extensions': ConversionExtension("Ext", lambda x: x),
|
||||
'reverse_input_channels': True,
|
||||
'scale': 0.5,
|
||||
'input': ['name', InputCutInfo("a", [1,2,3], numpy.float32, [5, 6, 7])],
|
||||
'batch': 1,
|
||||
'output': ["a", "b", "c"],
|
||||
'mean_values': [0.5, 0.3],
|
||||
'scale_values': {"a": np.array([0.4]), "b": [0.5, 0.6]},
|
||||
'source_layout': Layout("nchw"),
|
||||
'layout': {"a": LayoutMap("nchw","nhwc"), "b": "nc"},
|
||||
'transform': ('LowLatency2', {'use_const_initializer': False})}
|
||||
|
||||
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||
argv = pack_params_to_args_namespace(args, cli_parser)
|
||||
|
||||
assert argv.input_model == args['input_model']
|
||||
assert argv.extensions == [args['extensions']]
|
||||
assert argv.reverse_input_channels == args['reverse_input_channels']
|
||||
assert argv.scale == 0.5
|
||||
assert argv.batch == 1
|
||||
assert argv.input_shape == "[1,100,100,3],[2,3]"
|
||||
assert argv.input == "name,a[1 2 3]{f32}->[5 6 7]"
|
||||
assert argv.output == "a,b,c"
|
||||
assert argv.mean_values == "[0.5,0.3]"
|
||||
assert argv.scale_values == "a[0.4],b[0.5,0.6]"
|
||||
assert argv.source_layout == "[N,C,H,W]"
|
||||
assert argv.layout == "a(nchw->nhwc),b(nc)"
|
||||
assert argv.transform == "LowLatency2[use_const_initializer=False]"
|
||||
|
||||
for arg, value in vars(argv).items():
|
||||
if arg not in args:
|
||||
assert value == cli_parser.get_default(arg)
|
||||
|
||||
def test_not_existing_dir(self):
|
||||
args = {"input_model": "abc"}
|
||||
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||
|
||||
with self.assertRaisesRegex(Error, "The \"abc\" is not existing file or directory"):
|
||||
pack_params_to_args_namespace(args, cli_parser)
|
||||
|
||||
def test_unknown_params(self):
|
||||
args = {"input_model": os.path.dirname(__file__),
|
||||
"a": "b"}
|
||||
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||
|
||||
with self.assertRaisesRegex(Error, "Unrecognized argument: a"):
|
||||
pack_params_to_args_namespace(args, cli_parser)
|
||||
|
Loading…
Reference in New Issue
Block a user