convert_model() slow conversion time fix (#14666)
* Removed input_model from meta data dictionary. * Added test. * Changed check for more general case. * Test fixed. * Temporarily added debug print. * Fixed test. * Code corrections. * Small correction. * Added type check. * Added comments, small corrections. * Refactored MO convert_model() to have single parse_args(). * Small correction. * Fixed PyTorch converting. * Small correction. * Code refactoring, added tests.
This commit is contained in:
parent
e2635a0053
commit
993686b266
@ -427,6 +427,8 @@ def add_net_rt_info(net: Element, meta_info: dict):
|
|||||||
else:
|
else:
|
||||||
meta = SubElement(net, 'rt_info')
|
meta = SubElement(net, 'rt_info')
|
||||||
for key, value in meta_info.items():
|
for key, value in meta_info.items():
|
||||||
|
if isinstance(value, dict) and value == {}:
|
||||||
|
continue
|
||||||
add_meta_data_elem(meta, key, value)
|
add_meta_data_elem(meta, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from openvino.frontend import FrontEndManager
|
||||||
|
|
||||||
from openvino.tools.mo.convert_impl import _convert
|
from openvino.tools.mo.convert_impl import _convert
|
||||||
|
from openvino.tools.mo.utils.cli_parser import get_all_cli_parser
|
||||||
|
|
||||||
InputCutInfo = namedtuple("InputInfo", ["name", "shape", "type", "value"])
|
InputCutInfo = namedtuple("InputInfo", ["name", "shape", "type", "value"])
|
||||||
LayoutMap = namedtuple("LayoutMap", ["source_layout", "target_layout"])
|
LayoutMap = namedtuple("LayoutMap", ["source_layout", "target_layout"])
|
||||||
@ -44,4 +48,9 @@ def convert_model(input_model=None, **args):
|
|||||||
openvino.runtime.Model
|
openvino.runtime.Model
|
||||||
"""
|
"""
|
||||||
args.update({'input_model': input_model})
|
args.update({'input_model': input_model})
|
||||||
return _convert(**args)
|
|
||||||
|
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||||
|
framework = None if 'framework' not in args else args['framework']
|
||||||
|
|
||||||
|
ov_model, _ = _convert(cli_parser, framework, args)
|
||||||
|
return ov_model
|
||||||
|
@ -754,28 +754,18 @@ def args_dict_to_list(cli_parser, **kwargs):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def pack_params_to_args_namespace(**kwargs):
|
def get_non_default_params(argv, cli_parser):
|
||||||
fe_manager = FrontEndManager()
|
import numbers
|
||||||
cli_parser = get_all_cli_parser(fe_manager)
|
# make dictionary with parameters which have non-default values to be serialized in IR in rt_info
|
||||||
argv = cli_parser.parse_args(args_dict_to_list(cli_parser, **kwargs))
|
|
||||||
|
|
||||||
all_params = {}
|
|
||||||
for key, value in mo_convert_params.items():
|
|
||||||
all_params.update(value)
|
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if key not in argv and key not in all_params.keys():
|
|
||||||
raise Error("Unrecognized argument: {}".format(key))
|
|
||||||
if value is not None:
|
|
||||||
setattr(argv, key, value)
|
|
||||||
send_params_info(argv, cli_parser)
|
|
||||||
|
|
||||||
non_default_params = {}
|
non_default_params = {}
|
||||||
for arg in vars(argv):
|
for arg, arg_value in vars(argv).items():
|
||||||
arg_value = getattr(argv, arg)
|
|
||||||
if arg_value != cli_parser.get_default(arg):
|
if arg_value != cli_parser.get_default(arg):
|
||||||
non_default_params[arg] = depersonalize(arg_value, arg)
|
value = depersonalize(arg_value, arg)
|
||||||
return argv, non_default_params
|
# Skip complex classes in params to prevent
|
||||||
|
# serializing it to rt_info
|
||||||
|
if isinstance(value, (str, bool, numbers.Number)):
|
||||||
|
non_default_params[arg] = value
|
||||||
|
return non_default_params
|
||||||
|
|
||||||
|
|
||||||
def params_to_string(**kwargs):
|
def params_to_string(**kwargs):
|
||||||
@ -833,6 +823,10 @@ def show_mo_convert_help():
|
|||||||
|
|
||||||
|
|
||||||
def input_model_is_object(argv):
|
def input_model_is_object(argv):
|
||||||
|
# Input model can be set as object only for --input_model parameter.
|
||||||
|
# --saved_model_dir or meta specific options are only used to store paths to the input model.
|
||||||
|
if 'input_model' not in argv:
|
||||||
|
return False
|
||||||
if isinstance(argv['input_model'], (str, Path)):
|
if isinstance(argv['input_model'], (str, Path)):
|
||||||
return False
|
return False
|
||||||
if argv['input_model'] is None:
|
if argv['input_model'] is None:
|
||||||
@ -840,6 +834,30 @@ def input_model_is_object(argv):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def pack_params_to_args_namespace(args: dict, cli_parser: argparse.ArgumentParser):
|
||||||
|
if len(args) > 0:
|
||||||
|
args_string = params_to_string(**args)
|
||||||
|
argv, _ = cli_parser.parse_known_args(args_dict_to_list(cli_parser, **args_string))
|
||||||
|
|
||||||
|
# get list of all available params for convert_model()
|
||||||
|
all_params = {}
|
||||||
|
for key, value in mo_convert_params.items():
|
||||||
|
all_params.update(value)
|
||||||
|
|
||||||
|
# check that there are no unknown params provided
|
||||||
|
for key, value in args_string.items():
|
||||||
|
if key not in argv and key not in all_params.keys():
|
||||||
|
raise Error("Unrecognized argument: {}".format(key))
|
||||||
|
|
||||||
|
# Non string params like input_model or extensions are ignored by parse_args()
|
||||||
|
# so we need to set them in argv separately
|
||||||
|
if value is not None and getattr(argv, key) != value:
|
||||||
|
setattr(argv, key, value)
|
||||||
|
else:
|
||||||
|
argv = cli_parser.parse_args()
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
def remove_tmp_onnx_model(out_dir):
|
def remove_tmp_onnx_model(out_dir):
|
||||||
if not os.environ.get('SAVE_TO_BYTES_IO_ONNX_MODEL'):
|
if not os.environ.get('SAVE_TO_BYTES_IO_ONNX_MODEL'):
|
||||||
tmp_onnx_model = get_onnx_temp_filename(out_dir)
|
tmp_onnx_model = get_onnx_temp_filename(out_dir)
|
||||||
@ -848,7 +866,7 @@ def remove_tmp_onnx_model(out_dir):
|
|||||||
os.remove(tmp_onnx_model)
|
os.remove(tmp_onnx_model)
|
||||||
|
|
||||||
|
|
||||||
def _convert(**args):
|
def _convert(cli_parser: argparse.ArgumentParser, framework, args):
|
||||||
if 'help' in args and args['help']:
|
if 'help' in args and args['help']:
|
||||||
show_mo_convert_help()
|
show_mo_convert_help()
|
||||||
return None
|
return None
|
||||||
@ -886,19 +904,27 @@ def _convert(**args):
|
|||||||
args['onnx_opset_version'] = None
|
args['onnx_opset_version'] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ov_model = _convert(**args)
|
ov_model = _convert(cli_parser, framework, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
remove_tmp_onnx_model(out_dir)
|
remove_tmp_onnx_model(out_dir)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
remove_tmp_onnx_model(out_dir)
|
remove_tmp_onnx_model(out_dir)
|
||||||
return ov_model
|
return ov_model
|
||||||
args = params_to_string(**args)
|
|
||||||
argv, non_default_params = pack_params_to_args_namespace(**args)
|
argv = pack_params_to_args_namespace(args, cli_parser)
|
||||||
|
|
||||||
|
if framework is not None:
|
||||||
|
setattr(argv, 'framework', framework)
|
||||||
|
|
||||||
|
# send telemetry with params info
|
||||||
|
send_params_info(argv, cli_parser)
|
||||||
|
|
||||||
|
non_default_params = get_non_default_params(argv, cli_parser)
|
||||||
|
|
||||||
if inp_model_is_object:
|
if inp_model_is_object:
|
||||||
argv.model_name = "model"
|
argv.model_name = "model"
|
||||||
if argv.model_name is None:
|
if not hasattr(argv, "model_name") or argv.model_name is None:
|
||||||
argv.model_name = get_model_name_from_args(argv)
|
argv.model_name = get_model_name_from_args(argv)
|
||||||
|
|
||||||
if model_framework is not None:
|
if model_framework is not None:
|
||||||
@ -929,7 +955,7 @@ def _convert(**args):
|
|||||||
telemetry.send_event('mo', 'conversion_result', 'success')
|
telemetry.send_event('mo', 'conversion_result', 'success')
|
||||||
telemetry.end_session('mo')
|
telemetry.end_session('mo')
|
||||||
telemetry.force_shutdown(1.0)
|
telemetry.force_shutdown(1.0)
|
||||||
return ov_model
|
return ov_model, argv
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
telemetry.send_event('mo', 'conversion_result', 'fail')
|
telemetry.send_event('mo', 'conversion_result', 'fail')
|
||||||
telemetry.end_session('mo')
|
telemetry.end_session('mo')
|
||||||
|
@ -11,7 +11,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
import openvino.tools.mo.utils.telemetry_stub as tm
|
import openvino.tools.mo.utils.telemetry_stub as tm
|
||||||
|
|
||||||
from openvino.tools.mo.convert import convert_model
|
from openvino.tools.mo.convert_impl import _convert
|
||||||
from openvino.tools.mo.pipeline.common import get_ir_version
|
from openvino.tools.mo.pipeline.common import get_ir_version
|
||||||
from openvino.tools.mo.utils.cli_parser import get_model_name_from_args
|
from openvino.tools.mo.utils.cli_parser import get_model_name_from_args
|
||||||
from openvino.tools.mo.utils.logger import init_logger
|
from openvino.tools.mo.utils.logger import init_logger
|
||||||
@ -32,29 +32,20 @@ def main(cli_parser: argparse.ArgumentParser, framework=None):
|
|||||||
# Initialize logger with 'ERROR' as default level to be able to form nice messages
|
# Initialize logger with 'ERROR' as default level to be able to form nice messages
|
||||||
# before arg parser deliver log_level requested by user
|
# before arg parser deliver log_level requested by user
|
||||||
init_logger('ERROR', False)
|
init_logger('ERROR', False)
|
||||||
logger = log.getLogger()
|
|
||||||
# Disable logging for parse_args() as inner convert runs parse_args() second time
|
|
||||||
# which result in duplicating of warnings
|
|
||||||
logger.disabled = True
|
|
||||||
argv = cli_parser.parse_args()
|
|
||||||
logger.disabled = False
|
|
||||||
argv.model_name = get_model_name_from_args(argv)
|
|
||||||
is_tf, _, _, _, _ = deduce_legacy_frontend_by_namespace(argv)
|
|
||||||
argv = vars(argv)
|
|
||||||
|
|
||||||
if framework is not None:
|
|
||||||
argv['framework'] = framework
|
|
||||||
|
|
||||||
ngraph_function = None
|
ngraph_function = None
|
||||||
|
argv = None
|
||||||
|
is_tf = False
|
||||||
try:
|
try:
|
||||||
ngraph_function = convert_model(**argv)
|
ngraph_function, argv = _convert(cli_parser, framework, {})
|
||||||
|
is_tf, _, _, _, _ = deduce_legacy_frontend_by_namespace(argv)
|
||||||
ov_update_message = get_ov_update_message()
|
ov_update_message = get_ov_update_message()
|
||||||
ov_api20_message = get_ov_api20_message()
|
ov_api20_message = get_ov_api20_message()
|
||||||
if ov_update_message is not None:
|
if ov_update_message is not None:
|
||||||
print(ov_update_message)
|
print(ov_update_message)
|
||||||
if ov_api20_message is not None and ngraph_function is not None:
|
if ov_api20_message is not None and ngraph_function is not None:
|
||||||
print(ov_api20_message)
|
print(ov_api20_message)
|
||||||
if argv['use_new_frontend'] and is_tf:
|
if argv.use_new_frontend and is_tf:
|
||||||
print(get_tf_fe_message())
|
print(get_tf_fe_message())
|
||||||
|
|
||||||
except (FileNotFoundError, NotADirectoryError) as e:
|
except (FileNotFoundError, NotADirectoryError) as e:
|
||||||
@ -66,7 +57,7 @@ def main(cli_parser: argparse.ArgumentParser, framework=None):
|
|||||||
for el in analysis_results.get_messages():
|
for el in analysis_results.get_messages():
|
||||||
log.error(el, extra={'analysis_info': True})
|
log.error(el, extra={'analysis_info': True})
|
||||||
log.error(err)
|
log.error(err)
|
||||||
if not argv['use_new_frontend'] and is_tf:
|
if hasattr(argv, 'use_new_frontend') and not argv.use_new_frontend and is_tf:
|
||||||
print(get_tf_fe_legacy_message())
|
print(get_tf_fe_legacy_message())
|
||||||
log.debug(traceback.format_exc())
|
log.debug(traceback.format_exc())
|
||||||
except FrameworkError as err:
|
except FrameworkError as err:
|
||||||
@ -81,21 +72,21 @@ def main(cli_parser: argparse.ArgumentParser, framework=None):
|
|||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
log.error("---------------- END OF BUG REPORT --------------")
|
log.error("---------------- END OF BUG REPORT --------------")
|
||||||
log.error("-------------------------------------------------")
|
log.error("-------------------------------------------------")
|
||||||
if not argv['use_new_frontend'] and is_tf:
|
if hasattr(argv, 'use_new_frontend') and not argv.use_new_frontend and is_tf:
|
||||||
print(get_tf_fe_legacy_message())
|
print(get_tf_fe_legacy_message())
|
||||||
|
|
||||||
if ngraph_function is None:
|
if ngraph_function is None:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
output_dir = argv['output_dir'] if argv['output_dir'] != '.' else os.getcwd()
|
output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
|
||||||
model_path_no_ext = os.path.normpath(os.path.join(output_dir, argv['model_name']))
|
model_path_no_ext = os.path.normpath(os.path.join(output_dir, argv.model_name))
|
||||||
model_path = model_path_no_ext + '.xml'
|
model_path = model_path_no_ext + '.xml'
|
||||||
|
|
||||||
serialize(ngraph_function, model_path.encode('utf-8'), model_path.replace('.xml', '.bin').encode('utf-8'))
|
serialize(ngraph_function, model_path.encode('utf-8'), model_path.replace('.xml', '.bin').encode('utf-8'))
|
||||||
|
|
||||||
# generate .mapping file
|
# generate .mapping file
|
||||||
path_to_mapping = model_path_no_ext + ".mapping"
|
path_to_mapping = model_path_no_ext + ".mapping"
|
||||||
extract_names = argv['framework'] in ['tf', 'mxnet', 'kaldi']
|
extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
|
||||||
generate_mapping_file(ngraph_function, path_to_mapping, extract_names)
|
generate_mapping_file(ngraph_function, path_to_mapping, extract_names)
|
||||||
|
|
||||||
print('[ SUCCESS ] Generated IR version {} model.'.format(get_ir_version(argv)))
|
print('[ SUCCESS ] Generated IR version {} model.'.format(get_ir_version(argv)))
|
||||||
|
90
tools/mo/unit_tests/mo/convert/meta_data_test_actual.py
Normal file
90
tools/mo/unit_tests/mo/convert/meta_data_test_actual.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from openvino.runtime import get_version as get_rt_version
|
||||||
|
from openvino.runtime import serialize
|
||||||
|
|
||||||
|
from openvino.tools.mo import convert_model
|
||||||
|
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
||||||
|
from openvino.tools.mo.utils.version import get_version
|
||||||
|
|
||||||
|
|
||||||
|
class MetaDataTestTF(unittest.TestCase):
|
||||||
|
test_directory = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_meta_data(ov_model, ref_meta):
|
||||||
|
ignore_attrs = ['version', 'optimization']
|
||||||
|
for key, value in ref_meta.items():
|
||||||
|
if key == 'conversion_parameters':
|
||||||
|
for param_name, param_value in value.items():
|
||||||
|
val = ov_model.get_rt_info([key, param_name])
|
||||||
|
if param_name in ['extensions', 'caffe_parser_path', 'input_model', 'k', 'output_dir']:
|
||||||
|
val = Path(val)
|
||||||
|
assert val == param_value, \
|
||||||
|
"Runtime info attribute with name {} does not match. Expected: {}, " \
|
||||||
|
"got {}".format(param_name, param_value, val)
|
||||||
|
continue
|
||||||
|
assert str(ov_model.get_rt_info(key)) == value, \
|
||||||
|
"Runtime info attribute with name {} does not match. Expected: {}, " \
|
||||||
|
"got {}".format(key, value, ov_model.get_rt_info(key))
|
||||||
|
|
||||||
|
for key, value in ov_model.get_rt_info().items():
|
||||||
|
if key in ignore_attrs:
|
||||||
|
continue
|
||||||
|
assert key in ref_meta, "Unexpected runtime info attribute: {}".format(key)
|
||||||
|
|
||||||
|
def test_meta_data_tf(self):
|
||||||
|
def create_tf_model():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
tf.compat.v1.reset_default_graph()
|
||||||
|
|
||||||
|
with tf.compat.v1.Session() as sess:
|
||||||
|
inp1 = tf.compat.v1.placeholder(tf.float32, [1, 2, 3], 'Input')
|
||||||
|
inp2 = tf.compat.v1.placeholder(tf.float32, [1, 2, 3], 'Input')
|
||||||
|
relu = tf.nn.relu(inp1 + inp2, name='Relu')
|
||||||
|
|
||||||
|
output = tf.nn.sigmoid(relu, name='Sigmoid')
|
||||||
|
|
||||||
|
tf.compat.v1.global_variables_initializer()
|
||||||
|
tf_net = sess.graph_def
|
||||||
|
return tf_net
|
||||||
|
|
||||||
|
def ref_meta_data():
|
||||||
|
return {
|
||||||
|
'MO_version': get_version(),
|
||||||
|
'Runtime_version': get_rt_version(),
|
||||||
|
'legacy_frontend': "True",
|
||||||
|
'conversion_parameters': {
|
||||||
|
'scale': "1.5",
|
||||||
|
'batch': "1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory(dir=self.test_directory) as tmpdir:
|
||||||
|
|
||||||
|
model = create_tf_model()
|
||||||
|
out_xml = os.path.join(tmpdir, "model.xml")
|
||||||
|
ref_meta = ref_meta_data()
|
||||||
|
|
||||||
|
ov_model = convert_model(model, scale=1.5, batch=1)
|
||||||
|
self.check_meta_data(ov_model, ref_meta)
|
||||||
|
|
||||||
|
serialize(ov_model, out_xml.encode('utf-8'), out_xml.replace('.xml', '.bin').encode('utf-8'))
|
||||||
|
|
||||||
|
from openvino.runtime import Core
|
||||||
|
core = Core()
|
||||||
|
deserialized_model = core.read_model(out_xml)
|
||||||
|
self.check_meta_data(deserialized_model, ref_meta)
|
||||||
|
|
||||||
|
restored_graph, meta_data = restore_graph_from_ir(out_xml, out_xml.replace('.xml', '.bin'))
|
||||||
|
save_restored_graph(restored_graph, tmpdir, meta_data, "mo_ir_reader_test_model")
|
||||||
|
|
||||||
|
mo_ir_reader_test_model = core.read_model(os.path.join(tmpdir, "mo_ir_reader_test_model.xml"))
|
||||||
|
self.check_meta_data(mo_ir_reader_test_model, ref_meta)
|
@ -76,6 +76,15 @@ def test_main_error_log():
|
|||||||
assert test_log == ref_log
|
assert test_log == ref_log
|
||||||
|
|
||||||
|
|
||||||
|
def test_rt_info():
|
||||||
|
setup_env()
|
||||||
|
args = [sys.executable, '-m', 'pytest',
|
||||||
|
os.path.join(os.path.dirname(__file__), 'convert/meta_data_test_actual.py'), '-s']
|
||||||
|
|
||||||
|
status = subprocess.run(args, env=os.environ, capture_output=True)
|
||||||
|
assert not status.returncode
|
||||||
|
|
||||||
|
|
||||||
def test_mo_extensions_test():
|
def test_mo_extensions_test():
|
||||||
setup_env()
|
setup_env()
|
||||||
args = [sys.executable, '-m', 'pytest',
|
args = [sys.executable, '-m', 'pytest',
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import numpy
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -14,10 +16,13 @@ import numpy as np
|
|||||||
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
|
from openvino.tools.mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, \
|
||||||
get_model_name, \
|
get_model_name, \
|
||||||
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
||||||
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values, get_data_type_from_input_value
|
readable_file, get_freeze_placeholder_values, parse_transform, check_available_transforms, get_layout_values, get_data_type_from_input_value, get_all_cli_parser
|
||||||
|
from openvino.tools.mo.convert_impl import pack_params_to_args_namespace
|
||||||
|
from openvino.tools.mo.convert import InputCutInfo, LayoutMap
|
||||||
from openvino.tools.mo.utils.error import Error
|
from openvino.tools.mo.utils.error import Error
|
||||||
from unit_tests.mo.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry
|
from unit_tests.mo.unit_test_with_mocked_telemetry import UnitTestWithMockedTelemetry
|
||||||
from openvino.runtime import PartialShape, Dimension
|
from openvino.runtime import PartialShape, Dimension, Layout
|
||||||
|
from openvino.frontend import FrontEndManager
|
||||||
|
|
||||||
|
|
||||||
class TestingMeanScaleGetter(UnitTestWithMockedTelemetry):
|
class TestingMeanScaleGetter(UnitTestWithMockedTelemetry):
|
||||||
@ -1955,3 +1960,56 @@ class TestLayoutParsingEmptyNamesNoBrackets(unittest.TestCase):
|
|||||||
def wrong_case_3(self):
|
def wrong_case_3(self):
|
||||||
argv_source_layout = "nchv->"
|
argv_source_layout = "nchv->"
|
||||||
self.assertRaises(get_layout_values(argv_source_layout=argv_source_layout))
|
self.assertRaises(get_layout_values(argv_source_layout=argv_source_layout))
|
||||||
|
|
||||||
|
class TestPackParamsToArgsNamespace(unittest.TestCase):
|
||||||
|
def test_mo_convert_params(self):
|
||||||
|
from openvino.frontend import ConversionExtension
|
||||||
|
args = {'input_model': os.path.dirname(__file__),
|
||||||
|
'input_shape': [PartialShape([1,100,100,3]), [2,3]],
|
||||||
|
'extensions': ConversionExtension("Ext", lambda x: x),
|
||||||
|
'reverse_input_channels': True,
|
||||||
|
'scale': 0.5,
|
||||||
|
'input': ['name', InputCutInfo("a", [1,2,3], numpy.float32, [5, 6, 7])],
|
||||||
|
'batch': 1,
|
||||||
|
'output': ["a", "b", "c"],
|
||||||
|
'mean_values': [0.5, 0.3],
|
||||||
|
'scale_values': {"a": np.array([0.4]), "b": [0.5, 0.6]},
|
||||||
|
'source_layout': Layout("nchw"),
|
||||||
|
'layout': {"a": LayoutMap("nchw","nhwc"), "b": "nc"},
|
||||||
|
'transform': ('LowLatency2', {'use_const_initializer': False})}
|
||||||
|
|
||||||
|
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||||
|
argv = pack_params_to_args_namespace(args, cli_parser)
|
||||||
|
|
||||||
|
assert argv.input_model == args['input_model']
|
||||||
|
assert argv.extensions == [args['extensions']]
|
||||||
|
assert argv.reverse_input_channels == args['reverse_input_channels']
|
||||||
|
assert argv.scale == 0.5
|
||||||
|
assert argv.batch == 1
|
||||||
|
assert argv.input_shape == "[1,100,100,3],[2,3]"
|
||||||
|
assert argv.input == "name,a[1 2 3]{f32}->[5 6 7]"
|
||||||
|
assert argv.output == "a,b,c"
|
||||||
|
assert argv.mean_values == "[0.5,0.3]"
|
||||||
|
assert argv.scale_values == "a[0.4],b[0.5,0.6]"
|
||||||
|
assert argv.source_layout == "[N,C,H,W]"
|
||||||
|
assert argv.layout == "a(nchw->nhwc),b(nc)"
|
||||||
|
assert argv.transform == "LowLatency2[use_const_initializer=False]"
|
||||||
|
|
||||||
|
for arg, value in vars(argv).items():
|
||||||
|
if arg not in args:
|
||||||
|
assert value == cli_parser.get_default(arg)
|
||||||
|
|
||||||
|
def test_not_existing_dir(self):
|
||||||
|
args = {"input_model": "abc"}
|
||||||
|
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(Error, "The \"abc\" is not existing file or directory"):
|
||||||
|
pack_params_to_args_namespace(args, cli_parser)
|
||||||
|
|
||||||
|
def test_unknown_params(self):
|
||||||
|
args = {"input_model": os.path.dirname(__file__),
|
||||||
|
"a": "b"}
|
||||||
|
cli_parser = get_all_cli_parser(FrontEndManager())
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(Error, "Unrecognized argument: a"):
|
||||||
|
pack_params_to_args_namespace(args, cli_parser)
|
||||||
|
Loading…
Reference in New Issue
Block a user