diff --git a/model-optimizer/mo/back/offline_transformations.py b/model-optimizer/mo/back/offline_transformations.py index 1df5f6cb7a0..e97b9826872 100644 --- a/model-optimizer/mo/back/offline_transformations.py +++ b/model-optimizer/mo/back/offline_transformations.py @@ -17,16 +17,9 @@ def get_available_transformations(): return {} -def apply_offline_transformations(input_model: str, framework: str, transforms: list): - # This variable is only needed by GenerateMappingFile transformation - # to produce correct mapping - extract_names = framework in ['tf', 'mxnet', 'kaldi'] - - from openvino.inference_engine import read_network # pylint: disable=import-error,no-name-in-module - from openvino.offline_transformations import ApplyMOCTransformations, GenerateMappingFile # pylint: disable=import-error,no-name-in-module - - net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin") - +# net should be openvino.inference_engine.IENetwork type, but IE Engine is still optional dependency +def apply_moc_transformations(net: object, transforms: list): + from openvino.offline_transformations import ApplyMOCTransformations # pylint: disable=import-error,no-name-in-module available_transformations = get_available_transformations() for name, args in transforms: @@ -36,6 +29,18 @@ def apply_offline_transformations(input_model: str, framework: str, transforms: available_transformations[name](net, **args) ApplyMOCTransformations(net, False) + + +def apply_offline_transformations(input_model: str, framework: str, transforms: list): + # This variable is only needed by GenerateMappingFile transformation + # to produce correct mapping + extract_names = framework in ['tf', 'mxnet', 'kaldi'] + + from openvino.inference_engine import read_network # pylint: disable=import-error,no-name-in-module + from openvino.offline_transformations import GenerateMappingFile # pylint: disable=import-error,no-name-in-module + + net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin") + apply_moc_transformations(net, transforms) net.serialize(input_model + ".xml", input_model + ".bin") path_to_mapping = input_model + ".mapping" GenerateMappingFile(net, path_to_mapping.encode('utf-8'), extract_names) diff --git a/model-optimizer/mo/moc_frontend/serialize.py b/model-optimizer/mo/moc_frontend/serialize.py index 86433c6e36d..dd2533c6702 100644 --- a/model-optimizer/mo/moc_frontend/serialize.py +++ b/model-optimizer/mo/moc_frontend/serialize.py @@ -5,7 +5,7 @@ import argparse import os from mo.pipeline.common import get_ir_version from mo.back.ie_ir_ver_2.emitter import append_ir_info -from mo.utils.cli_parser import get_meta_info +from mo.utils.cli_parser import get_meta_info, parse_transform from ngraph import Function # pylint: disable=no-name-in-module,import-error from ngraph import function_to_cnn # pylint: disable=no-name-in-module,import-error @@ -15,6 +15,8 @@ def moc_emit_ir(ngraph_function: Function, argv: argparse.Namespace): output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd() network = function_to_cnn(ngraph_function) + from mo.back.offline_transformations import apply_moc_transformations + apply_moc_transformations(network, parse_transform(argv.transform)) orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name)) network.serialize(orig_model_name + ".xml", orig_model_name + ".bin")