Apply moc transformations on FE API path (#6871)

* apply moc transformations

* changed type of net in apply_moc_transformations

* review remarks

* args->argv typo
This commit is contained in:
Mateusz Bencer 2021-08-05 17:08:40 +02:00 committed by GitHub
parent 79e9190838
commit e4dfff387b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 11 deletions

View File

@ -17,16 +17,9 @@ def get_available_transformations():
return {}
def apply_offline_transformations(input_model: str, framework: str, transforms: list):
# This variable is only needed by GenerateMappingFile transformation
# to produce correct mapping
extract_names = framework in ['tf', 'mxnet', 'kaldi']
from openvino.inference_engine import read_network # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import ApplyMOCTransformations, GenerateMappingFile # pylint: disable=import-error,no-name-in-module
net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin")
# net should be openvino.inference_engine.IENetwork type, but IE Engine is still optional dependency
def apply_moc_transformations(net: object, transforms: list):
from openvino.offline_transformations import ApplyMOCTransformations # pylint: disable=import-error,no-name-in-module
available_transformations = get_available_transformations()
for name, args in transforms:
@ -36,6 +29,18 @@ def apply_offline_transformations(input_model: str, framework: str, transforms:
available_transformations[name](net, **args)
ApplyMOCTransformations(net, False)
def apply_offline_transformations(input_model: str, framework: str, transforms: list):
# This variable is only needed by GenerateMappingFile transformation
# to produce correct mapping
extract_names = framework in ['tf', 'mxnet', 'kaldi']
from openvino.inference_engine import read_network # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import GenerateMappingFile # pylint: disable=import-error,no-name-in-module
net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin")
apply_moc_transformations(net, transforms)
net.serialize(input_model + ".xml", input_model + ".bin")
path_to_mapping = input_model + ".mapping"
GenerateMappingFile(net, path_to_mapping.encode('utf-8'), extract_names)

View File

@ -5,7 +5,7 @@ import argparse
import os
from mo.pipeline.common import get_ir_version
from mo.back.ie_ir_ver_2.emitter import append_ir_info
from mo.utils.cli_parser import get_meta_info
from mo.utils.cli_parser import get_meta_info, parse_transform
from ngraph import Function # pylint: disable=no-name-in-module,import-error
from ngraph import function_to_cnn # pylint: disable=no-name-in-module,import-error
@ -15,6 +15,8 @@ def moc_emit_ir(ngraph_function: Function, argv: argparse.Namespace):
output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
network = function_to_cnn(ngraph_function)
from mo.back.offline_transformations import apply_moc_transformations
apply_moc_transformations(network, parse_transform(argv.transform))
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
network.serialize(orig_model_name + ".xml", orig_model_name + ".bin")