diff --git a/model-optimizer/mo/main.py b/model-optimizer/mo/main.py index 3a9553a2bc0..f3998be279e 100644 --- a/model-optimizer/mo/main.py +++ b/model-optimizer/mo/main.py @@ -96,7 +96,7 @@ def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet: print('\n'.join(lines), flush=True) -def prepare_ir(argv: argparse.Namespace): +def get_moc_frontends(argv: argparse.Namespace): fem = argv.feManager available_moc_front_ends = [] moc_front_end = None @@ -117,6 +117,12 @@ def prepare_ir(argv: argparse.Namespace): elif argv.framework in available_moc_front_ends: moc_front_end = fem.load_by_framework(argv.framework) + return moc_front_end, available_moc_front_ends + + +def arguments_post_parsing(argv: argparse.Namespace): + moc_front_end, available_moc_front_ends = get_moc_frontends(argv) + is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx =\ deduce_framework_by_namespace(argv) if not moc_front_end else [False, False, False, False, False] @@ -279,8 +285,15 @@ def prepare_ir(argv: argparse.Namespace): from mo.front.onnx.register_custom_ops import get_front_classes import_extensions.load_dirs(argv.framework, extensions, get_front_classes) + return argv + + +def prepare_ir(argv): + argv = arguments_post_parsing(argv) + graph = None ngraph_function = None + moc_front_end, available_moc_front_ends = get_moc_frontends(argv) if argv.framework not in available_moc_front_ends: graph = unified_pipeline(argv)