Change the way Model Optimizer loads frontends (#7330)

* Adds two switches `use_new_frontend` and `use_legacy_frontend` to override defaults.
* Rename ONNX frontend from `onnx_experimental` to `onnx`
This commit is contained in:
Michał Karzyński
2021-09-30 12:17:36 +02:00
committed by GitHub
parent f3c8f2bc49
commit bd5b1bf99f
12 changed files with 73 additions and 35 deletions

View File

@@ -44,7 +44,7 @@ from mo.utils.version import get_version, get_simplified_mo_version, get_simplif
from mo.utils.versions_checker import check_requirements # pylint: disable=no-name-in-module
# pylint: disable=no-name-in-module,import-error
from ngraph.frontend import FrontEndManager
from ngraph.frontend import FrontEndManager, FrontEnd
def replace_ext(name: str, old: str, new: str):
@@ -98,24 +98,33 @@ def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet:
def get_moc_frontends(argv: argparse.Namespace):
fem = argv.feManager
available_moc_front_ends = []
moc_front_end = None
# TODO: in future, check of 'use_legacy_frontend' in argv can be added here (issue 61973)
force_use_legacy_frontend = False
# Read user flags:
use_legacy_frontend = argv.use_legacy_frontend
use_new_frontend = argv.use_new_frontend
if fem and not force_use_legacy_frontend:
available_moc_front_ends = fem.get_available_front_ends()
if argv.input_model:
if not argv.framework:
moc_front_end = fem.load_by_model(argv.input_model)
# skip onnx frontend as not fully supported yet (63050)
if moc_front_end and moc_front_end.get_name() == "onnx":
moc_front_end = None
if moc_front_end:
argv.framework = moc_front_end.get_name()
elif argv.framework in available_moc_front_ends:
moc_front_end = fem.load_by_framework(argv.framework)
if not fem or use_legacy_frontend:
return None, []
available_moc_front_ends = fem.get_available_front_ends()
if not argv.framework and argv.input_model:
moc_front_end = fem.load_by_model(argv.input_model)
if not moc_front_end:
return None, available_moc_front_ends
argv.framework = moc_front_end.get_name()
elif argv.framework in available_moc_front_ends:
moc_front_end = fem.load_by_framework(argv.framework)
else:
return None, []
# Set which frontend to use by default, values should be 'new' or 'legacy'
frontend_defaults = {
'onnx': 'legacy',
}
# Disable MOC frontend if default is set to legacy and no user override
if frontend_defaults.get(moc_front_end.get_name()) == 'legacy' and not use_new_frontend:
moc_front_end = None
return moc_front_end, available_moc_front_ends
@@ -130,8 +139,13 @@ def arguments_post_parsing(argv: argparse.Namespace):
frameworks = ['tf', 'caffe', 'mxnet', 'kaldi', 'onnx']
frameworks = list(set(frameworks + available_moc_front_ends))
if argv.framework not in frameworks:
raise Error('Framework {} is not a valid target. Please use --framework with one from the list: {}. ' +
refer_to_faq_msg(15), argv.framework, frameworks)
if argv.use_legacy_frontend:
raise Error('Framework {} is not a valid target when using the --use_legacy_frontend flag. '
'The following legacy frameworks are available: {}' +
refer_to_faq_msg(15), argv.framework, frameworks)
else:
raise Error('Framework {} is not a valid target. Please use --framework with one from the list: {}. ' +
refer_to_faq_msg(15), argv.framework, frameworks)
if is_tf and not argv.input_model and not argv.saved_model_dir and not argv.input_meta_graph:
raise Error('Path to input model or saved model dir is required: use --input_model, --saved_model_dir or '
@@ -195,11 +209,13 @@ def arguments_post_parsing(argv: argparse.Namespace):
if argv.legacy_ir_generation and len(argv.transform) != 0:
raise Error("--legacy_ir_generation and --transform keys can not be used at the same time.")
use_legacy_fe = argv.framework not in available_moc_front_ends
# For C++ frontends there is no specific python installation requirements, thus check only generic ones
ret_code = check_requirements(framework=argv.framework if use_legacy_fe else None)
# For C++ frontends there are no specific Python installation requirements, check only generic ones
if moc_front_end:
ret_code = check_requirements()
else:
ret_code = check_requirements(framework=argv.framework)
if ret_code:
raise Error('check_requirements exit with return code {}'.format(ret_code))
raise Error('check_requirements exited with return code {}'.format(ret_code))
if is_tf and argv.tensorflow_use_custom_operations_config is not None:
argv.transformations_config = argv.tensorflow_use_custom_operations_config
@@ -295,10 +311,11 @@ def prepare_ir(argv):
ngraph_function = None
moc_front_end, available_moc_front_ends = get_moc_frontends(argv)
if argv.framework not in available_moc_front_ends:
graph = unified_pipeline(argv)
else:
if moc_front_end:
ngraph_function = moc_pipeline(argv, moc_front_end)
else:
graph = unified_pipeline(argv)
return graph, ngraph_function

View File

@@ -7,4 +7,6 @@ from mo.utils.cli_parser import get_onnx_cli_parser
if __name__ == "__main__":
from mo.main import main
sys.exit(main(get_onnx_cli_parser(), None, 'onnx'))
from ngraph.frontend import FrontEndManager # pylint: disable=no-name-in-module,import-error
sys.exit(main(get_onnx_cli_parser(), FrontEndManager(), 'onnx'))

View File

@@ -19,6 +19,7 @@ from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
from mo.utils.version import get_version
class DeprecatedStoreTrue(argparse.Action):
def __init__(self, nargs=0, **kw):
super().__init__(nargs=nargs, **kw)
@@ -353,10 +354,16 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
help='Switch model conversion progress display to a multiline mode.',
action='store_true', default=False)
common_group.add_argument('--transformations_config',
help='Use the configuration file with transformations description.',
action=CanonicalizePathCheckExistenceAction)
help='Use the configuration file with transformations description.',
action=CanonicalizePathCheckExistenceAction)
common_group.add_argument('--legacy_ir_generation',
help=argparse.SUPPRESS, action=DeprecatedStoreTrue, default=False)
common_group.add_argument("--use_new_frontend",
help="Use new frontend API for model processing",
action='store_true', default=False)
common_group.add_argument("--use_legacy_frontend",
help="Use legacy API for model processing",
action='store_true', default=False)
return parser
@@ -378,6 +385,7 @@ def get_common_cli_options(model_name):
d['disable_gfusing'] = ['- Enable grouped convolutions fusing', lambda x: not x]
d['move_to_preprocess'] = '- Move mean values to preprocess section'
d['reverse_input_channels'] = '- Reverse input channels'
d['use_legacy_frontend'] = '- Use legacy API for model processing'
return d

View File

@@ -42,3 +42,4 @@ def guess_framework_by_ext(input_model_path: str) -> int:
return 'kaldi'
elif re.match(r'^.*\.onnx$', input_model_path):
return 'onnx'

View File

@@ -74,6 +74,8 @@ def replaceArgsHelper(log_level='DEBUG',
scale_values=scale_values,
output_dir=output_dir,
freeze_placeholder_with_value=freeze_placeholder_with_value,
use_legacy_frontend=None,
use_new_frontend=None,
framework=None)

View File

@@ -26,7 +26,7 @@ extern "C" ONNX_FRONTEND_API FrontEndVersion GetAPIVersion() {
extern "C" ONNX_FRONTEND_API void* GetFrontEndData() {
FrontEndPluginInfo* res = new FrontEndPluginInfo();
res->m_name = "onnx_experimental";
res->m_name = "onnx";
res->m_creator = []() {
return std::make_shared<FrontEndONNX>();
};

View File

@@ -6,4 +6,4 @@
#include <string>
static const std::string ONNX_FE = "onnx_experimental";
static const std::string ONNX_FE = "onnx";

View File

@@ -1,3 +1,3 @@
# external weights are not supported for ONNX Frontend
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoFiles/onnx_experimental
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoStreams/onnx_experimental
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoFiles/onnx
ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoStreams/onnx

View File

@@ -136,4 +136,8 @@ void regclass_pyngraph_FrontEnd(py::module m) {
get_name : str
Current frontend name. Empty string if not implemented.
)");
fem.def("__repr__", [](const ngraph::frontend::FrontEnd& self) -> std::string {
return "<FrontEnd '" + self.get_name() + "'>";
});
}

View File

@@ -78,6 +78,10 @@ void regclass_pyngraph_FrontEndManager(py::module m) {
load_by_model : FrontEnd
Frontend interface for further loading of models. 'None' if no suitable frontend is found
)");
fem.def("__repr__", [](const ngraph::frontend::FrontEndManager& self) -> std::string {
return "<FrontEndManager>";
});
}
void regclass_pyngraph_GeneralFailureFrontEnd(py::module m) {

View File

@@ -37,7 +37,7 @@ def run_function(function, *inputs, expected):
fem = FrontEndManager()
onnx_model_filename = "model.onnx"
ONNX_FRONTEND_NAME = "onnx_experimental"
ONNX_FRONTEND_NAME = "onnx"
def setup_module():

View File

@@ -217,7 +217,7 @@ def create_test_onnx_models():
fem = FrontEndManager()
test_models_names = []
ONNX_FRONTEND_NAME = "onnx_experimental"
ONNX_FRONTEND_NAME = "onnx"
def setup_module():