takeaway parsing arguments in separate func (#7266)

This commit is contained in:
Ruslan Nugmanov 2021-09-21 14:45:05 +03:00 committed by GitHub
parent db385569c2
commit 5ad2400468
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -96,7 +96,7 @@ def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet:
print('\n'.join(lines), flush=True) print('\n'.join(lines), flush=True)
def prepare_ir(argv: argparse.Namespace): def get_moc_frontends(argv: argparse.Namespace):
fem = argv.feManager fem = argv.feManager
available_moc_front_ends = [] available_moc_front_ends = []
moc_front_end = None moc_front_end = None
@ -117,6 +117,12 @@ def prepare_ir(argv: argparse.Namespace):
elif argv.framework in available_moc_front_ends: elif argv.framework in available_moc_front_ends:
moc_front_end = fem.load_by_framework(argv.framework) moc_front_end = fem.load_by_framework(argv.framework)
return moc_front_end, available_moc_front_ends
def arguments_post_parsing(argv: argparse.Namespace):
moc_front_end, available_moc_front_ends = get_moc_frontends(argv)
is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx =\ is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx =\
deduce_framework_by_namespace(argv) if not moc_front_end else [False, False, False, False, False] deduce_framework_by_namespace(argv) if not moc_front_end else [False, False, False, False, False]
@ -279,8 +285,15 @@ def prepare_ir(argv: argparse.Namespace):
from mo.front.onnx.register_custom_ops import get_front_classes from mo.front.onnx.register_custom_ops import get_front_classes
import_extensions.load_dirs(argv.framework, extensions, get_front_classes) import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
return argv
def prepare_ir(argv):
argv = arguments_post_parsing(argv)
graph = None graph = None
ngraph_function = None ngraph_function = None
moc_front_end, available_moc_front_ends = get_moc_frontends(argv)
if argv.framework not in available_moc_front_ends: if argv.framework not in available_moc_front_ends:
graph = unified_pipeline(argv) graph = unified_pipeline(argv)