[MO] IR Reader meta_info usage update (#7133)

* Initial change to allow saving graph to IR without meta_info

* Update saving and restoring functions, add more comments

* Add unit tests for define_data_type() function

* Fix wrong name

* Update condition

* Update meta_data checks

* Update data_type restoration and missing output ports handling

* Update and add new test

* Update comments

* Remove commented code

* Rename function

* Update temporary Result operations processing

* Remove define_data_type function

* Move node_normalize_output function to Op class methods

* Update comments

* Update comments
Anton Chetverikov 2021-10-05 15:07:44 +03:00 committed by GitHub
parent cdb3e17763
commit a56d81345d
8 changed files with 68 additions and 51 deletions

View File

@@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
from mo.graph.graph import Graph, Node
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.result import Result
from mo.ops.op import Op
class AddFakeOutputsToSplit(MiddleReplacementPattern):
@@ -23,18 +23,7 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
def find_and_replace_pattern(self, graph: Graph):
for split_node in graph.get_op_nodes(op='Split'):
AddFakeOutputsToSplit.split_normalize_outputs(split_node)
@staticmethod
def split_normalize_outputs(node: Node):
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
for p in range(node.out_ports_count):
if p not in node.out_ports():
node.add_output_port(p)
if node.out_port(p).disconnected():
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
'keep_output_port': True}).create_node()
node.out_port(p).connect(res_node.in_port(0))
Op.normalize_outputs(split_node)
class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
@@ -72,4 +61,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
if not node.has_valid('out_ports_count'):
node['out_ports_count'] = len(size_splits)
AddFakeOutputsToSplit().split_normalize_outputs(node)
Op.normalize_outputs(node)
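The diff above replaces the Split-specific `split_normalize_outputs` helper with the generic `Op.normalize_outputs` introduced in the `Op` class diff further down. The standalone sketch below illustrates the underlying idea with simplified stand-in classes (`FakeNode` and `FakeResult` are not the real `mo.graph.graph.Node` / `mo.ops.result.Result` APIs): any declared output port without a consumer gets a placeholder Result-like consumer so shape inference sees a fully connected node.

```python
# Minimal sketch of the "normalize outputs" pattern; these classes are
# illustrative stand-ins, not the real MO Node/Result implementations.
class FakeResult:
    """Stand-in for a Result operation that just consumes a tensor."""
    def __init__(self, name):
        self.name = name


class FakeNode:
    """Stand-in for an MO graph node with a declared number of output ports."""
    def __init__(self, name, out_ports_count, consumers):
        self.name = name
        self.out_ports_count = out_ports_count
        # existing port index -> consumer name; None means the port is disconnected
        self.consumers = dict(consumers)

    def normalize_outputs(self):
        for p in range(self.out_ports_count):
            if p not in self.consumers:
                self.consumers[p] = None  # add the missing output port
            if self.consumers[p] is None:  # disconnected -> attach a fake Result
                self.consumers[p] = FakeResult('{}/Fake_output_{}/'.format(self.name, p)).name


split = FakeNode('Split', out_ports_count=3, consumers={0: 'next_op'})
split.normalize_outputs()
print(split.consumers)  # {0: 'next_op', 1: 'Split/Fake_output_1/', 2: 'Split/Fake_output_2/'}
```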

View File

@@ -364,12 +364,16 @@ def add_quantization_info_section(net: Element, meta_info: dict):
def add_meta_data(net: Element, meta_info: dict):
meta = SubElement(net, 'meta_data')
SubElement(meta, 'MO_version').set('value', get_version())
parameters = SubElement(meta, 'cli_parameters')
[SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if
key not in ('unset', 'quantization_parameters')]
SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))
if meta_info == {}:
log.warning('`meta_info` is not provided, IR will not contain appropriate section.')
else:
meta = SubElement(net, 'meta_data')
SubElement(meta, 'MO_version').set('value', get_version())
parameters = SubElement(meta, 'cli_parameters')
[SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if
key not in ('unset', 'quantization_parameters')]
if 'unset' in meta_info:
SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))
def serialize_network(graph, net_element, unsupported):
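A runnable sketch of the guarded `meta_data` serialization shown above, built only on the standard library; `add_meta_data_sketch` and the hard-coded version string are illustrative stand-ins for the real `add_meta_data` / `get_version` pair:

```python
import logging as log
from xml.etree.ElementTree import Element, SubElement, tostring


def add_meta_data_sketch(net, meta_info, mo_version='custom'):
    # Skip the whole meta_data section when meta_info is empty.
    if meta_info == {}:
        log.warning('`meta_info` is not provided, IR will not contain appropriate section.')
        return
    meta = SubElement(net, 'meta_data')
    SubElement(meta, 'MO_version').set('value', mo_version)
    parameters = SubElement(meta, 'cli_parameters')
    for key in sorted(meta_info.keys()):
        if key not in ('unset', 'quantization_parameters'):
            SubElement(parameters, str(key)).set('value', str(meta_info[key]))
    # Emit the 'unset' element only when the key is actually present.
    if 'unset' in meta_info:
        SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))


net = Element('net')
add_meta_data_sketch(net, {'data_type': 'FP32', 'unset': ['mean_values']})
print(tostring(net).decode())

add_meta_data_sketch(Element('net'), {})  # only logs a warning, adds nothing
```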

View File

@@ -294,7 +294,6 @@ class Op(object):
"""
return self.attrs.get('version', 'extension')
@classmethod
def update_node_stat(cls, node: Node, attrs: dict = None):
if attrs is None:
@@ -330,6 +329,18 @@ class Op(object):
node.value = np.expand_dims(node.value, axis=-1)
node.shape = np.array(node.value.shape)
@staticmethod
def normalize_outputs(node: Node):
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
from mo.ops.result import Result # Import is here to avoid circular import error
for p in range(node.out_ports_count):
if p not in node.out_ports():
node.add_output_port(p)
if node.out_port(p).disconnected():
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
'keep_output_port': True}).create_node()
node.out_port(p).connect(res_node.in_port(0))
class PermuteAttrs:
Permutation = namedtuple('Permutation', ['perm', 'inv'])

View File

@@ -17,7 +17,8 @@ class PriorBox_extender(Extender):
for attr in attrs:
PriorBox_extender.attr_restore(op, attr)
if op.graph.graph['cmd_params'].framework == 'mxnet':
if 'framework' in op.graph.graph['cmd_params'] and op.graph.graph['cmd_params'].framework == 'mxnet':
# Need to use separate shape inference function as done in MO pipeline.
op['infer'] = multi_box_prior_infer_mxnet
op['stop_attr_upd'] = True
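The membership test added above protects the extender when `cmd_params` was restored from an IR and never contained a `framework` option. Below is a small illustration of the pattern, assuming `cmd_params` is an `argparse.Namespace` as in the MO pipeline (the attribute values are made up); the same defensive check is applied to the `progress` flag in a later file of this commit.

```python
from argparse import Namespace

# Hypothetical cmd_params objects: an IR restored by the IR reader may carry a
# cmd_params namespace that never had a 'framework' option set.
full_params = Namespace(framework='mxnet', data_type='FP32')
restored_params = Namespace(data_type='FP32')  # no 'framework' attribute

for cmd_params in (full_params, restored_params):
    # argparse.Namespace supports the 'in' operator, so the membership test
    # avoids an AttributeError when the option was never set.
    if 'framework' in cmd_params and cmd_params.framework == 'mxnet':
        print('use the MXNet-specific multi_box_prior shape inference')
    else:
        print('keep the default PriorBox shape inference')
```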

View File

@@ -7,7 +7,6 @@ import os
import numpy as np
from extensions.back.TopKNormalizer import TopKNormalizer
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
from extensions.ops.Cast import Cast
from extensions.ops.ReduceOps import ReduceOp
from extensions.ops.activation_ops import Activation
@@ -26,7 +25,6 @@ from mo.ops.convolution import Convolution
from mo.ops.deconvolution import Deconvolution
from mo.ops.op import Op
from mo.ops.pooling import Pooling
from mo.ops.result import Result
from mo.utils.class_registration import update_registration
from mo.utils.import_extensions import import_by_path
from mo.utils.ir_reader.extender import Extender
@@ -240,18 +238,6 @@ def ti_add_edge_attrs(op: Node):
i += 1
def assign_add_output_result(op: Node):
"""
Function adds necessary output result node for Assign node
:param op:
:return:
"""
assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
''.format(op.soft_get('type'))
tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
op.out_port(0).connect(tmp_result.in_port(0))
def copy_input_blobs(op: Node, copy_op: Node):
"""
Function copies input blob data nodes from the restored graph to the copied one
@@ -272,17 +258,12 @@ preprocessing_op_nodes = {
'GroupConvolution': groupconv_to_conv,
'ConvolutionBackpropData': backprop_to_deconv,
'GroupConvolutionBackpropData': backprop_to_deconv,
}
# Map with postprocessing functions for nodes
postprocessing_op_nodes = {
'Assign': assign_add_output_result,
'TensorIterator': ti_add_edge_attrs,
'TopK': TopKNormalizer.normalize_outputs,
# Call normalize Split outputs for generated IR by ir-reader
'Split': AddFakeOutputsToSplit.split_normalize_outputs,
'VariadicSplit': AddFakeOutputsToSplit.split_normalize_outputs,
}
@@ -377,6 +358,10 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
else:
node = Op.get_op_class_by_name(op_type)(new_graph, op.attrs()).create_node()
# Fill out_ports_count attribute
if 'out_ports_count' not in node and node.soft_get('type') != 'Result':
node['out_ports_count'] = len(op.out_edges())
# This attribute is no longer needed and we can delete it
if 'ir_data_attrs' in node:
del node['ir_data_attrs']
@@ -398,6 +383,19 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
# Nodes postprocessing stage in new graph
for op in new_graph.get_op_nodes():
# Call normalize node outputs for restored operations to connect temporary Result operations to disconnected
# output ports. This is needed for correct shape inference; these Result operations are removed during
# IR emitting. TopK outputs are normalized separately with the dedicated
# TopKNormalizer.normalize_outputs function.
if op.soft_get('type') != 'TopK':
Op.normalize_outputs(op)
# Set correct_data_type attribute on Const data nodes for correct processing of restored values
if op.soft_get('type') == 'Const':
assert len(op.out_nodes()) == 1 and op.out_node(0).soft_get('kind') == 'data',\
'Const node {} is not properly connected to appropriate data node'.format(op.soft_get('name'))
op.out_node(0)['correct_data_type'] = True
restore_tensor_names(op)
# operations postprocessing with some special types
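Taken together, the postprocessing changes above mean that every restored node except TopK goes through the generic `Op.normalize_outputs`, Const outputs get the `correct_data_type` mark, and type-specific handlers from the trimmed `postprocessing_op_nodes` map still run. The sketch below mirrors that dispatch with plain dictionaries instead of the real `mo.*` graph API; the node records and handler bodies are illustrative only.

```python
# Illustrative node records; the real code iterates over mo Node objects.
restored_nodes = [
    {'type': 'Split', 'out_data': {}},
    {'type': 'TopK', 'out_data': {}},
    {'type': 'Const', 'out_data': {}},
]


def normalize_outputs(node):
    # generic padding of disconnected output ports (see Op.normalize_outputs above)
    print('normalize outputs of', node['type'])


def topk_normalize_outputs(node):
    # TopK keeps its dedicated normalizer in the type-specific map
    print('TopK-specific output normalization')


# trimmed type-specific map, mirroring postprocessing_op_nodes after this change
postprocessing_op_nodes = {'TopK': topk_normalize_outputs}

for node in restored_nodes:
    if node['type'] != 'TopK':
        normalize_outputs(node)
    if node['type'] == 'Const':
        # mark the Const output data node so restored values keep their data type
        node['out_data']['correct_data_type'] = True
    handler = postprocessing_op_nodes.get(node['type'])
    if handler is not None:
        handler(node)

print(restored_nodes[-1])  # the Const record now carries correct_data_type
```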

View File

@@ -1,14 +1,17 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
from copy import copy
import numpy as np
from extensions.back.ConvolutionNormalizer import ConvolutionNormalizer, ConvolutionWithGroupsResolver
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
from extensions.back.PackBinaryWeights import PackBinaryWeights
from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConstNodesReplacement
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
from extensions.back.blob_normalizer import BlobNormalizer
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_precision
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
@@ -17,7 +20,6 @@ from mo.utils.class_registration import apply_replacements_list
from mo.utils.ir_engine.ir_engine import IREngine
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
from mo.utils.utils import get_mo_root_dir
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
@@ -54,8 +56,20 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
if name is None:
name = graph.name
precision = data_type_str_to_precision(graph.graph['cmd_params'].data_type)
assert precision in ['FP16', 'FP32'], 'Cannot define precision for restored model!'
if 'data_type' not in meta_data:
log.debug('Provided `meta_data` does not contain `data_type` parameter. Set `data_type`'
' parameter value to `FP32`.')
# Set data_type to FP32. All restored constants will be saved in provided data type.
data_type = 'FP32'
# We need to specify this attribute to pass graph transformations. This information will not be saved into IR.
# All constants and placeholders will be saved with the same types as restored from the IR
graph.graph['cmd_params'].data_type = data_type
else:
data_type = data_type_str_to_precision(graph.graph['cmd_params'].data_type)
assert data_type in ['FP16', 'FP32'], '`data_type` value {} is not supported by MO,' \
' cannot save graph'.format(data_type)
# List items order matters, do not change it.
transformation_list = [
@@ -64,7 +78,6 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
PackBinaryWeights,
BlobNormalizer,
ConvolutionNormalizer,
KaldiRemoveMemoryOutputBackReplacementPattern,
MarkNodesWithShapeValues,
]
@@ -75,4 +88,4 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
for_graph_and_each_sub_graph_recursively(graph, RemoveConstOps().find_and_replace_pattern)
for_graph_and_each_sub_graph_recursively(graph, CreateConstNodesReplacement().find_and_replace_pattern)
prepare_emit_ir(graph, precision, path, name, meta_info=meta_data)
prepare_emit_ir(graph, data_type, path, name, meta_info=meta_data)
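The `data_type` fallback added to `save_restored_graph` can be summarized with the standalone helper below. It is a hedged approximation: `pick_data_type` and the inline precision map stand in for the real function and `data_type_str_to_precision`, and the exact string-to-precision mapping is assumed rather than taken from the source.

```python
import logging as log
from argparse import Namespace


def pick_data_type(meta_data, cmd_params):
    """Approximate the patched save_restored_graph() data_type selection."""
    if 'data_type' not in meta_data:
        log.debug('`meta_data` has no `data_type` parameter, defaulting to FP32.')
        # Set only so later graph transformations have a value; not written to the IR.
        cmd_params.data_type = 'FP32'
        return 'FP32'
    # Assumed stand-in mapping for data_type_str_to_precision().
    precision_map = {'float': 'FP32', 'FP32': 'FP32', 'half': 'FP16', 'FP16': 'FP16'}
    data_type = precision_map.get(cmd_params.data_type, cmd_params.data_type)
    assert data_type in ['FP16', 'FP32'], \
        '`data_type` value {} is not supported by MO, cannot save graph'.format(data_type)
    return data_type


print(pick_data_type({}, Namespace()))                                     # FP32
print(pick_data_type({'data_type': 'half'}, Namespace(data_type='half')))  # FP16
```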

View File

@@ -92,7 +92,7 @@ def progress_bar(function: callable):
assert arg in kwargs, msg.format(arg, 'is missing')
assert kwargs[arg] is not None, msg.format(arg, 'should not be None')
if kwargs['graph'].graph['cmd_params'].progress:
if 'progress' in kwargs['graph'].graph['cmd_params'] and kwargs['graph'].graph['cmd_params'].progress:
bar_len = 20
total_replacers_count = kwargs['num_transforms']

View File

@@ -2,11 +2,12 @@
# SPDX-License-Identifier: Apache-2.0
import os
import unittest
import tempfile
import unittest
from defusedxml.common import EntitiesForbidden
from mo.utils.ir_reader.restore_graph import restore_graph_from_ir
from defusedxml.common import EntitiesForbidden
class TestIRReader(unittest.TestCase):