Apply moc transformations on FE API path (#6871)
* apply moc transformations * changed type of net in apply_moc_transformations * review remarks * args->argv typo
This commit is contained in:
parent
79e9190838
commit
e4dfff387b
@ -17,16 +17,9 @@ def get_available_transformations():
|
||||
return {}
|
||||
|
||||
|
||||
def apply_offline_transformations(input_model: str, framework: str, transforms: list):
|
||||
# This variable is only needed by GenerateMappingFile transformation
|
||||
# to produce correct mapping
|
||||
extract_names = framework in ['tf', 'mxnet', 'kaldi']
|
||||
|
||||
from openvino.inference_engine import read_network # pylint: disable=import-error,no-name-in-module
|
||||
from openvino.offline_transformations import ApplyMOCTransformations, GenerateMappingFile # pylint: disable=import-error,no-name-in-module
|
||||
|
||||
net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin")
|
||||
|
||||
# net should be openvino.inference_engine.IENetwork type, but IE Engine is still optional dependency
|
||||
def apply_moc_transformations(net: object, transforms: list):
|
||||
from openvino.offline_transformations import ApplyMOCTransformations # pylint: disable=import-error,no-name-in-module
|
||||
available_transformations = get_available_transformations()
|
||||
|
||||
for name, args in transforms:
|
||||
@ -36,6 +29,18 @@ def apply_offline_transformations(input_model: str, framework: str, transforms:
|
||||
available_transformations[name](net, **args)
|
||||
|
||||
ApplyMOCTransformations(net, False)
|
||||
|
||||
|
||||
def apply_offline_transformations(input_model: str, framework: str, transforms: list):
|
||||
# This variable is only needed by GenerateMappingFile transformation
|
||||
# to produce correct mapping
|
||||
extract_names = framework in ['tf', 'mxnet', 'kaldi']
|
||||
|
||||
from openvino.inference_engine import read_network # pylint: disable=import-error,no-name-in-module
|
||||
from openvino.offline_transformations import GenerateMappingFile # pylint: disable=import-error,no-name-in-module
|
||||
|
||||
net = read_network(input_model + "_tmp.xml", input_model + "_tmp.bin")
|
||||
apply_moc_transformations(net, transforms)
|
||||
net.serialize(input_model + ".xml", input_model + ".bin")
|
||||
path_to_mapping = input_model + ".mapping"
|
||||
GenerateMappingFile(net, path_to_mapping.encode('utf-8'), extract_names)
|
||||
|
@ -5,7 +5,7 @@ import argparse
|
||||
import os
|
||||
from mo.pipeline.common import get_ir_version
|
||||
from mo.back.ie_ir_ver_2.emitter import append_ir_info
|
||||
from mo.utils.cli_parser import get_meta_info
|
||||
from mo.utils.cli_parser import get_meta_info, parse_transform
|
||||
|
||||
from ngraph import Function # pylint: disable=no-name-in-module,import-error
|
||||
from ngraph import function_to_cnn # pylint: disable=no-name-in-module,import-error
|
||||
@ -15,6 +15,8 @@ def moc_emit_ir(ngraph_function: Function, argv: argparse.Namespace):
|
||||
output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
|
||||
|
||||
network = function_to_cnn(ngraph_function)
|
||||
from mo.back.offline_transformations import apply_moc_transformations
|
||||
apply_moc_transformations(network, parse_transform(argv.transform))
|
||||
|
||||
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
|
||||
network.serialize(orig_model_name + ".xml", orig_model_name + ".bin")
|
||||
|
Loading…
Reference in New Issue
Block a user