Change the way Model Optimizer loads frontends (#7330)
* Adds two switches `use_new_frontend` and `use_legacy_frontend` to override defaults. * Rename ONNX frontend from `onnx_experimental` to `onnx`
This commit is contained in:
@@ -44,7 +44,7 @@ from mo.utils.version import get_version, get_simplified_mo_version, get_simplif
|
||||
from mo.utils.versions_checker import check_requirements # pylint: disable=no-name-in-module
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from ngraph.frontend import FrontEndManager
|
||||
from ngraph.frontend import FrontEndManager, FrontEnd
|
||||
|
||||
|
||||
def replace_ext(name: str, old: str, new: str):
|
||||
@@ -98,24 +98,33 @@ def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet:
|
||||
|
||||
def get_moc_frontends(argv: argparse.Namespace):
|
||||
fem = argv.feManager
|
||||
available_moc_front_ends = []
|
||||
moc_front_end = None
|
||||
|
||||
# TODO: in future, check of 'use_legacy_frontend' in argv can be added here (issue 61973)
|
||||
force_use_legacy_frontend = False
|
||||
# Read user flags:
|
||||
use_legacy_frontend = argv.use_legacy_frontend
|
||||
use_new_frontend = argv.use_new_frontend
|
||||
|
||||
if fem and not force_use_legacy_frontend:
|
||||
available_moc_front_ends = fem.get_available_front_ends()
|
||||
if argv.input_model:
|
||||
if not argv.framework:
|
||||
moc_front_end = fem.load_by_model(argv.input_model)
|
||||
# skip onnx frontend as not fully supported yet (63050)
|
||||
if moc_front_end and moc_front_end.get_name() == "onnx":
|
||||
moc_front_end = None
|
||||
if moc_front_end:
|
||||
argv.framework = moc_front_end.get_name()
|
||||
elif argv.framework in available_moc_front_ends:
|
||||
moc_front_end = fem.load_by_framework(argv.framework)
|
||||
if not fem or use_legacy_frontend:
|
||||
return None, []
|
||||
|
||||
available_moc_front_ends = fem.get_available_front_ends()
|
||||
|
||||
if not argv.framework and argv.input_model:
|
||||
moc_front_end = fem.load_by_model(argv.input_model)
|
||||
if not moc_front_end:
|
||||
return None, available_moc_front_ends
|
||||
argv.framework = moc_front_end.get_name()
|
||||
elif argv.framework in available_moc_front_ends:
|
||||
moc_front_end = fem.load_by_framework(argv.framework)
|
||||
else:
|
||||
return None, []
|
||||
|
||||
# Set which frontend to use by default, values should be 'new' or 'legacy'
|
||||
frontend_defaults = {
|
||||
'onnx': 'legacy',
|
||||
}
|
||||
# Disable MOC frontend if default is set to legacy and no user override
|
||||
if frontend_defaults.get(moc_front_end.get_name()) == 'legacy' and not use_new_frontend:
|
||||
moc_front_end = None
|
||||
|
||||
return moc_front_end, available_moc_front_ends
|
||||
|
||||
@@ -130,8 +139,13 @@ def arguments_post_parsing(argv: argparse.Namespace):
|
||||
frameworks = ['tf', 'caffe', 'mxnet', 'kaldi', 'onnx']
|
||||
frameworks = list(set(frameworks + available_moc_front_ends))
|
||||
if argv.framework not in frameworks:
|
||||
raise Error('Framework {} is not a valid target. Please use --framework with one from the list: {}. ' +
|
||||
refer_to_faq_msg(15), argv.framework, frameworks)
|
||||
if argv.use_legacy_frontend:
|
||||
raise Error('Framework {} is not a valid target when using the --use_legacy_frontend flag. '
|
||||
'The following legacy frameworks are available: {}' +
|
||||
refer_to_faq_msg(15), argv.framework, frameworks)
|
||||
else:
|
||||
raise Error('Framework {} is not a valid target. Please use --framework with one from the list: {}. ' +
|
||||
refer_to_faq_msg(15), argv.framework, frameworks)
|
||||
|
||||
if is_tf and not argv.input_model and not argv.saved_model_dir and not argv.input_meta_graph:
|
||||
raise Error('Path to input model or saved model dir is required: use --input_model, --saved_model_dir or '
|
||||
@@ -195,11 +209,13 @@ def arguments_post_parsing(argv: argparse.Namespace):
|
||||
if argv.legacy_ir_generation and len(argv.transform) != 0:
|
||||
raise Error("--legacy_ir_generation and --transform keys can not be used at the same time.")
|
||||
|
||||
use_legacy_fe = argv.framework not in available_moc_front_ends
|
||||
# For C++ frontends there is no specific python installation requirements, thus check only generic ones
|
||||
ret_code = check_requirements(framework=argv.framework if use_legacy_fe else None)
|
||||
# For C++ frontends there are no specific Python installation requirements, check only generic ones
|
||||
if moc_front_end:
|
||||
ret_code = check_requirements()
|
||||
else:
|
||||
ret_code = check_requirements(framework=argv.framework)
|
||||
if ret_code:
|
||||
raise Error('check_requirements exit with return code {}'.format(ret_code))
|
||||
raise Error('check_requirements exited with return code {}'.format(ret_code))
|
||||
|
||||
if is_tf and argv.tensorflow_use_custom_operations_config is not None:
|
||||
argv.transformations_config = argv.tensorflow_use_custom_operations_config
|
||||
@@ -295,10 +311,11 @@ def prepare_ir(argv):
|
||||
ngraph_function = None
|
||||
moc_front_end, available_moc_front_ends = get_moc_frontends(argv)
|
||||
|
||||
if argv.framework not in available_moc_front_ends:
|
||||
graph = unified_pipeline(argv)
|
||||
else:
|
||||
if moc_front_end:
|
||||
ngraph_function = moc_pipeline(argv, moc_front_end)
|
||||
else:
|
||||
graph = unified_pipeline(argv)
|
||||
|
||||
return graph, ngraph_function
|
||||
|
||||
|
||||
|
||||
@@ -7,4 +7,6 @@ from mo.utils.cli_parser import get_onnx_cli_parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
from mo.main import main
|
||||
sys.exit(main(get_onnx_cli_parser(), None, 'onnx'))
|
||||
from ngraph.frontend import FrontEndManager # pylint: disable=no-name-in-module,import-error
|
||||
|
||||
sys.exit(main(get_onnx_cli_parser(), FrontEndManager(), 'onnx'))
|
||||
|
||||
@@ -19,6 +19,7 @@ from mo.utils.error import Error
|
||||
from mo.utils.utils import refer_to_faq_msg
|
||||
from mo.utils.version import get_version
|
||||
|
||||
|
||||
class DeprecatedStoreTrue(argparse.Action):
|
||||
def __init__(self, nargs=0, **kw):
|
||||
super().__init__(nargs=nargs, **kw)
|
||||
@@ -353,10 +354,16 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
|
||||
help='Switch model conversion progress display to a multiline mode.',
|
||||
action='store_true', default=False)
|
||||
common_group.add_argument('--transformations_config',
|
||||
help='Use the configuration file with transformations description.',
|
||||
action=CanonicalizePathCheckExistenceAction)
|
||||
help='Use the configuration file with transformations description.',
|
||||
action=CanonicalizePathCheckExistenceAction)
|
||||
common_group.add_argument('--legacy_ir_generation',
|
||||
help=argparse.SUPPRESS, action=DeprecatedStoreTrue, default=False)
|
||||
common_group.add_argument("--use_new_frontend",
|
||||
help="Use new frontend API for model processing",
|
||||
action='store_true', default=False)
|
||||
common_group.add_argument("--use_legacy_frontend",
|
||||
help="Use legacy API for model processing",
|
||||
action='store_true', default=False)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -378,6 +385,7 @@ def get_common_cli_options(model_name):
|
||||
d['disable_gfusing'] = ['- Enable grouped convolutions fusing', lambda x: not x]
|
||||
d['move_to_preprocess'] = '- Move mean values to preprocess section'
|
||||
d['reverse_input_channels'] = '- Reverse input channels'
|
||||
d['use_legacy_frontend'] = '- Use legacy API for model processing'
|
||||
return d
|
||||
|
||||
|
||||
|
||||
@@ -42,3 +42,4 @@ def guess_framework_by_ext(input_model_path: str) -> int:
|
||||
return 'kaldi'
|
||||
elif re.match(r'^.*\.onnx$', input_model_path):
|
||||
return 'onnx'
|
||||
|
||||
|
||||
@@ -74,6 +74,8 @@ def replaceArgsHelper(log_level='DEBUG',
|
||||
scale_values=scale_values,
|
||||
output_dir=output_dir,
|
||||
freeze_placeholder_with_value=freeze_placeholder_with_value,
|
||||
use_legacy_frontend=None,
|
||||
use_new_frontend=None,
|
||||
framework=None)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ extern "C" ONNX_FRONTEND_API FrontEndVersion GetAPIVersion() {
|
||||
|
||||
extern "C" ONNX_FRONTEND_API void* GetFrontEndData() {
|
||||
FrontEndPluginInfo* res = new FrontEndPluginInfo();
|
||||
res->m_name = "onnx_experimental";
|
||||
res->m_name = "onnx";
|
||||
res->m_creator = []() {
|
||||
return std::make_shared<FrontEndONNX>();
|
||||
};
|
||||
|
||||
@@ -6,4 +6,4 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
static const std::string ONNX_FE = "onnx_experimental";
|
||||
static const std::string ONNX_FE = "onnx";
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# external weights are not supported for ONNX Frontend
|
||||
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoFiles/onnx_experimental
|
||||
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoStreams/onnx_experimental
|
||||
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoFiles/onnx
|
||||
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoStreams/onnx
|
||||
|
||||
@@ -136,4 +136,8 @@ void regclass_pyngraph_FrontEnd(py::module m) {
|
||||
get_name : str
|
||||
Current frontend name. Empty string if not implemented.
|
||||
)");
|
||||
|
||||
fem.def("__repr__", [](const ngraph::frontend::FrontEnd& self) -> std::string {
|
||||
return "<FrontEnd '" + self.get_name() + "'>";
|
||||
});
|
||||
}
|
||||
|
||||
@@ -78,6 +78,10 @@ void regclass_pyngraph_FrontEndManager(py::module m) {
|
||||
load_by_model : FrontEnd
|
||||
Frontend interface for further loading of models. 'None' if no suitable frontend is found
|
||||
)");
|
||||
|
||||
fem.def("__repr__", [](const ngraph::frontend::FrontEndManager& self) -> std::string {
|
||||
return "<FrontEndManager>";
|
||||
});
|
||||
}
|
||||
|
||||
void regclass_pyngraph_GeneralFailureFrontEnd(py::module m) {
|
||||
|
||||
@@ -37,7 +37,7 @@ def run_function(function, *inputs, expected):
|
||||
|
||||
fem = FrontEndManager()
|
||||
onnx_model_filename = "model.onnx"
|
||||
ONNX_FRONTEND_NAME = "onnx_experimental"
|
||||
ONNX_FRONTEND_NAME = "onnx"
|
||||
|
||||
|
||||
def setup_module():
|
||||
|
||||
@@ -217,7 +217,7 @@ def create_test_onnx_models():
|
||||
|
||||
fem = FrontEndManager()
|
||||
test_models_names = []
|
||||
ONNX_FRONTEND_NAME = "onnx_experimental"
|
||||
ONNX_FRONTEND_NAME = "onnx"
|
||||
|
||||
|
||||
def setup_module():
|
||||
|
||||
Reference in New Issue
Block a user