[MO] IR Reader meta_info using update (#7133)

* Initial change to allow saving graphto IR without meta_info

* Update saving and restoring functions, add more comments

* Add unit tests for define_data_type() function

* Fix wrong name

* Update condition

* Update meta_data checks

* Update data_type restoretion and missed outpurt ports handling

* Update and add new test

* Update comments

* Remove commented code

* Rename function

* Update temporary Result operations processing

* Remove define_data_type function

* Move node_normalize_output function to Op class methods

* Update comments

* Update comments
This commit is contained in:
Anton Chetverikov 2021-10-05 15:07:44 +03:00 committed by GitHub
parent cdb3e17763
commit a56d81345d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 68 additions and 51 deletions

View File

@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
from mo.graph.graph import Graph, Node from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.result import Result from mo.ops.op import Op
class AddFakeOutputsToSplit(MiddleReplacementPattern): class AddFakeOutputsToSplit(MiddleReplacementPattern):
@ -23,18 +23,7 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
def find_and_replace_pattern(self, graph: Graph): def find_and_replace_pattern(self, graph: Graph):
for split_node in graph.get_op_nodes(op='Split'): for split_node in graph.get_op_nodes(op='Split'):
AddFakeOutputsToSplit.split_normalize_outputs(split_node) Op.normalize_outputs(split_node)
@staticmethod
def split_normalize_outputs(node: Node):
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
for p in range(node.out_ports_count):
if p not in node.out_ports():
node.add_output_port(p)
if node.out_port(p).disconnected():
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
'keep_output_port': True}).create_node()
node.out_port(p).connect(res_node.in_port(0))
class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern): class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
@ -72,4 +61,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
if not node.has_valid('out_ports_count'): if not node.has_valid('out_ports_count'):
node['out_ports_count'] = len(size_splits) node['out_ports_count'] = len(size_splits)
AddFakeOutputsToSplit().split_normalize_outputs(node) Op.normalize_outputs(node)

View File

@ -364,12 +364,16 @@ def add_quantization_info_section(net: Element, meta_info: dict):
def add_meta_data(net: Element, meta_info: dict): def add_meta_data(net: Element, meta_info: dict):
meta = SubElement(net, 'meta_data') if meta_info == {}:
SubElement(meta, 'MO_version').set('value', get_version()) log.warning('`meta_info` is not provided, IR will not contain appropriate section.')
parameters = SubElement(meta, 'cli_parameters') else:
[SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if meta = SubElement(net, 'meta_data')
key not in ('unset', 'quantization_parameters')] SubElement(meta, 'MO_version').set('value', get_version())
SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset']))) parameters = SubElement(meta, 'cli_parameters')
[SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if
key not in ('unset', 'quantization_parameters')]
if 'unset' in meta_info:
SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))
def serialize_network(graph, net_element, unsupported): def serialize_network(graph, net_element, unsupported):

View File

@ -294,7 +294,6 @@ class Op(object):
""" """
return self.attrs.get('version', 'extension') return self.attrs.get('version', 'extension')
@classmethod @classmethod
def update_node_stat(cls, node: Node, attrs: dict = None): def update_node_stat(cls, node: Node, attrs: dict = None):
if attrs is None: if attrs is None:
@ -330,6 +329,18 @@ class Op(object):
node.value = np.expand_dims(node.value, axis=-1) node.value = np.expand_dims(node.value, axis=-1)
node.shape = np.array(node.value.shape) node.shape = np.array(node.value.shape)
@staticmethod
def normalize_outputs(node: Node):
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
from mo.ops.result import Result # Import is here to avoid circular import error
for p in range(node.out_ports_count):
if p not in node.out_ports():
node.add_output_port(p)
if node.out_port(p).disconnected():
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
'keep_output_port': True}).create_node()
node.out_port(p).connect(res_node.in_port(0))
class PermuteAttrs: class PermuteAttrs:
Permutation = namedtuple('Permutation', ['perm', 'inv']) Permutation = namedtuple('Permutation', ['perm', 'inv'])

View File

@ -17,7 +17,8 @@ class PriorBox_extender(Extender):
for attr in attrs: for attr in attrs:
PriorBox_extender.attr_restore(op, attr) PriorBox_extender.attr_restore(op, attr)
if op.graph.graph['cmd_params'].framework == 'mxnet': if 'framework' in op.graph.graph['cmd_params'] and op.graph.graph['cmd_params'].framework == 'mxnet':
# Need to use separate shape inference function as done in MO pipeline.
op['infer'] = multi_box_prior_infer_mxnet op['infer'] = multi_box_prior_infer_mxnet
op['stop_attr_upd'] = True op['stop_attr_upd'] = True

View File

@ -7,7 +7,6 @@ import os
import numpy as np import numpy as np
from extensions.back.TopKNormalizer import TopKNormalizer from extensions.back.TopKNormalizer import TopKNormalizer
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
from extensions.ops.Cast import Cast from extensions.ops.Cast import Cast
from extensions.ops.ReduceOps import ReduceOp from extensions.ops.ReduceOps import ReduceOp
from extensions.ops.activation_ops import Activation from extensions.ops.activation_ops import Activation
@ -26,7 +25,6 @@ from mo.ops.convolution import Convolution
from mo.ops.deconvolution import Deconvolution from mo.ops.deconvolution import Deconvolution
from mo.ops.op import Op from mo.ops.op import Op
from mo.ops.pooling import Pooling from mo.ops.pooling import Pooling
from mo.ops.result import Result
from mo.utils.class_registration import update_registration from mo.utils.class_registration import update_registration
from mo.utils.import_extensions import import_by_path from mo.utils.import_extensions import import_by_path
from mo.utils.ir_reader.extender import Extender from mo.utils.ir_reader.extender import Extender
@ -240,18 +238,6 @@ def ti_add_edge_attrs(op: Node):
i += 1 i += 1
def assign_add_output_result(op: Node):
"""
Function adds necessary output result node for Assign node
:param op:
:return:
"""
assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
''.format(op.soft_get('type'))
tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
op.out_port(0).connect(tmp_result.in_port(0))
def copy_input_blobs(op: Node, copy_op: Node): def copy_input_blobs(op: Node, copy_op: Node):
""" """
Function copy input blob data nodes from restored graph to copied one Function copy input blob data nodes from restored graph to copied one
@ -272,17 +258,12 @@ preprocessing_op_nodes = {
'GroupConvolution': groupconv_to_conv, 'GroupConvolution': groupconv_to_conv,
'ConvolutionBackpropData': backprop_to_deconv, 'ConvolutionBackpropData': backprop_to_deconv,
'GroupConvolutionBackpropData': backprop_to_deconv, 'GroupConvolutionBackpropData': backprop_to_deconv,
} }
# Map with postprocessing functions for nodes # Map with postprocessing functions for nodes
postprocessing_op_nodes = { postprocessing_op_nodes = {
'Assign': assign_add_output_result,
'TensorIterator': ti_add_edge_attrs, 'TensorIterator': ti_add_edge_attrs,
'TopK': TopKNormalizer.normalize_outputs, 'TopK': TopKNormalizer.normalize_outputs,
# Call normalize Split outputs for generated IR by ir-reader
'Split': AddFakeOutputsToSplit.split_normalize_outputs,
'VariadicSplit': AddFakeOutputsToSplit.split_normalize_outputs,
} }
@ -377,6 +358,10 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
else: else:
node = Op.get_op_class_by_name(op_type)(new_graph, op.attrs()).create_node() node = Op.get_op_class_by_name(op_type)(new_graph, op.attrs()).create_node()
# Fill out_ports_count attribute
if 'out_ports_count' not in node and node.soft_get('type') != 'Result':
node['out_ports_count'] = len(op.out_edges())
# This attribute is no longer needed and we can delete it # This attribute is no longer needed and we can delete it
if 'ir_data_attrs' in node: if 'ir_data_attrs' in node:
del node['ir_data_attrs'] del node['ir_data_attrs']
@ -398,6 +383,19 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
# Nodes postprocessing stage in new graph # Nodes postprocessing stage in new graph
for op in new_graph.get_op_nodes(): for op in new_graph.get_op_nodes():
# Call normalize node outputs for restored operations to connect temporary Result operations for disconnected
# output ports. We need to do that for correct shape inference. These Result operations will be removed during
# IR emitting. For TopK operation outputs normalizing we should use specific
# function TopKNormalizer.normalize_outputs.
if op.soft_get('type') != 'TopK':
Op.normalize_outputs(op)
# Set correct_data_type attribute to Const data nodes to correct processing of restored values
if op.soft_get('type') == 'Const':
assert len(op.out_nodes()) == 1 and op.out_node(0).soft_get('kind') == 'data',\
'Const node {} not properly corrected to appropriate data node'.format(op.soft_get('name'))
op.out_node(0)['correct_data_type'] = True
restore_tensor_names(op) restore_tensor_names(op)
# operations postprocessing with some special types # operations postprocessing with some special types

View File

@ -1,14 +1,17 @@
# Copyright (C) 2018-2021 Intel Corporation # Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging as log
from copy import copy from copy import copy
import numpy as np
from extensions.back.ConvolutionNormalizer import ConvolutionNormalizer, ConvolutionWithGroupsResolver from extensions.back.ConvolutionNormalizer import ConvolutionNormalizer, ConvolutionWithGroupsResolver
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
from extensions.back.PackBinaryWeights import PackBinaryWeights from extensions.back.PackBinaryWeights import PackBinaryWeights
from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConstNodesReplacement from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConstNodesReplacement
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
from extensions.back.blob_normalizer import BlobNormalizer from extensions.back.blob_normalizer import BlobNormalizer
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from mo.graph.graph import Graph from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_precision from mo.middle.passes.convert_data_type import data_type_str_to_precision
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
@ -17,7 +20,6 @@ from mo.utils.class_registration import apply_replacements_list
from mo.utils.ir_engine.ir_engine import IREngine from mo.utils.ir_engine.ir_engine import IREngine
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
from mo.utils.utils import get_mo_root_dir from mo.utils.utils import get_mo_root_dir
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict): def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
@ -54,8 +56,20 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
if name is None: if name is None:
name = graph.name name = graph.name
precision = data_type_str_to_precision(graph.graph['cmd_params'].data_type) if 'data_type' not in meta_data:
assert precision in ['FP16', 'FP32'], 'Cannot define precision for restored model!' log.debug('Provided `meta_data` does not contain `data_type` parameter. Set `data_type`'
' parameter value to `FP32`.')
# Set data_type to FP32. All restored constants will be saved in provided data type.
data_type = 'FP32'
# We need to specify this attribute to pass graph transformations. This information will not be saved into IR.
# All constants and placeholders will be saved with same types as restored from IR
graph.graph['cmd_params'].data_type = data_type
else:
data_type = data_type_str_to_precision(graph.graph['cmd_params'].data_type)
assert data_type in ['FP16', 'FP32'], '`data_type` value {} is not supported by MO,' \
' cannot save graph'.format(data_type)
# List items order matters, do not change it. # List items order matters, do not change it.
transformation_list = [ transformation_list = [
@ -64,7 +78,6 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
PackBinaryWeights, PackBinaryWeights,
BlobNormalizer, BlobNormalizer,
ConvolutionNormalizer, ConvolutionNormalizer,
KaldiRemoveMemoryOutputBackReplacementPattern,
MarkNodesWithShapeValues, MarkNodesWithShapeValues,
] ]
@ -75,4 +88,4 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
for_graph_and_each_sub_graph_recursively(graph, RemoveConstOps().find_and_replace_pattern) for_graph_and_each_sub_graph_recursively(graph, RemoveConstOps().find_and_replace_pattern)
for_graph_and_each_sub_graph_recursively(graph, CreateConstNodesReplacement().find_and_replace_pattern) for_graph_and_each_sub_graph_recursively(graph, CreateConstNodesReplacement().find_and_replace_pattern)
prepare_emit_ir(graph, precision, path, name, meta_info=meta_data) prepare_emit_ir(graph, data_type, path, name, meta_info=meta_data)

View File

@ -92,7 +92,7 @@ def progress_bar(function: callable):
assert arg in kwargs, msg.format(arg, 'is missing') assert arg in kwargs, msg.format(arg, 'is missing')
assert kwargs[arg] is not None, msg.format(arg, 'should not be None') assert kwargs[arg] is not None, msg.format(arg, 'should not be None')
if kwargs['graph'].graph['cmd_params'].progress: if 'progress' in kwargs['graph'].graph['cmd_params'] and kwargs['graph'].graph['cmd_params'].progress:
bar_len = 20 bar_len = 20
total_replacers_count = kwargs['num_transforms'] total_replacers_count = kwargs['num_transforms']

View File

@ -2,11 +2,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import unittest
import tempfile import tempfile
import unittest
from defusedxml.common import EntitiesForbidden
from mo.utils.ir_reader.restore_graph import restore_graph_from_ir from mo.utils.ir_reader.restore_graph import restore_graph_from_ir
from defusedxml.common import EntitiesForbidden
class TestIRReader(unittest.TestCase): class TestIRReader(unittest.TestCase):