[MO] IR Reader meta_info using update (#7133)
* Initial change to allow saving graphto IR without meta_info * Update saving and restoring functions, add more comments * Add unit tests for define_data_type() function * Fix wrong name * Update condition * Update meta_data checks * Update data_type restoretion and missed outpurt ports handling * Update and add new test * Update comments * Remove commented code * Rename function * Update temporary Result operations processing * Remove define_data_type function * Move node_normalize_output function to Op class methods * Update comments * Update comments
This commit is contained in:
parent
cdb3e17763
commit
a56d81345d
@ -2,9 +2,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.result import Result
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class AddFakeOutputsToSplit(MiddleReplacementPattern):
|
||||
@ -23,18 +23,7 @@ class AddFakeOutputsToSplit(MiddleReplacementPattern):
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for split_node in graph.get_op_nodes(op='Split'):
|
||||
AddFakeOutputsToSplit.split_normalize_outputs(split_node)
|
||||
|
||||
@staticmethod
|
||||
def split_normalize_outputs(node: Node):
|
||||
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
|
||||
for p in range(node.out_ports_count):
|
||||
if p not in node.out_ports():
|
||||
node.add_output_port(p)
|
||||
if node.out_port(p).disconnected():
|
||||
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
|
||||
'keep_output_port': True}).create_node()
|
||||
node.out_port(p).connect(res_node.in_port(0))
|
||||
Op.normalize_outputs(split_node)
|
||||
|
||||
|
||||
class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
|
||||
@ -72,4 +61,4 @@ class AddFakeOutputsToVariadicSplit(MiddleReplacementPattern):
|
||||
if not node.has_valid('out_ports_count'):
|
||||
node['out_ports_count'] = len(size_splits)
|
||||
|
||||
AddFakeOutputsToSplit().split_normalize_outputs(node)
|
||||
Op.normalize_outputs(node)
|
||||
|
@ -364,12 +364,16 @@ def add_quantization_info_section(net: Element, meta_info: dict):
|
||||
|
||||
|
||||
def add_meta_data(net: Element, meta_info: dict):
|
||||
meta = SubElement(net, 'meta_data')
|
||||
SubElement(meta, 'MO_version').set('value', get_version())
|
||||
parameters = SubElement(meta, 'cli_parameters')
|
||||
[SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if
|
||||
key not in ('unset', 'quantization_parameters')]
|
||||
SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))
|
||||
if meta_info == {}:
|
||||
log.warning('`meta_info` is not provided, IR will not contain appropriate section.')
|
||||
else:
|
||||
meta = SubElement(net, 'meta_data')
|
||||
SubElement(meta, 'MO_version').set('value', get_version())
|
||||
parameters = SubElement(meta, 'cli_parameters')
|
||||
[SubElement(parameters, str(key)).set('value', str(meta_info[key])) for key in sorted(meta_info.keys()) if
|
||||
key not in ('unset', 'quantization_parameters')]
|
||||
if 'unset' in meta_info:
|
||||
SubElement(parameters, 'unset').set('unset_cli_parameters', ', '.join(sorted(meta_info['unset'])))
|
||||
|
||||
|
||||
def serialize_network(graph, net_element, unsupported):
|
||||
|
@ -294,7 +294,6 @@ class Op(object):
|
||||
"""
|
||||
return self.attrs.get('version', 'extension')
|
||||
|
||||
|
||||
@classmethod
|
||||
def update_node_stat(cls, node: Node, attrs: dict = None):
|
||||
if attrs is None:
|
||||
@ -330,6 +329,18 @@ class Op(object):
|
||||
node.value = np.expand_dims(node.value, axis=-1)
|
||||
node.shape = np.array(node.value.shape)
|
||||
|
||||
@staticmethod
|
||||
def normalize_outputs(node: Node):
|
||||
if node.has_valid('out_ports_count') and len(node.out_edges()) < node.out_ports_count:
|
||||
from mo.ops.result import Result # Import is here to avoid circular import error
|
||||
for p in range(node.out_ports_count):
|
||||
if p not in node.out_ports():
|
||||
node.add_output_port(p)
|
||||
if node.out_port(p).disconnected():
|
||||
res_node = Result(node.graph, {'name': node.name + '/Fake_output_{}/'.format(p),
|
||||
'keep_output_port': True}).create_node()
|
||||
node.out_port(p).connect(res_node.in_port(0))
|
||||
|
||||
|
||||
class PermuteAttrs:
|
||||
Permutation = namedtuple('Permutation', ['perm', 'inv'])
|
||||
|
@ -17,7 +17,8 @@ class PriorBox_extender(Extender):
|
||||
for attr in attrs:
|
||||
PriorBox_extender.attr_restore(op, attr)
|
||||
|
||||
if op.graph.graph['cmd_params'].framework == 'mxnet':
|
||||
if 'framework' in op.graph.graph['cmd_params'] and op.graph.graph['cmd_params'].framework == 'mxnet':
|
||||
# Need to use separate shape inference function as done in MO pipeline.
|
||||
op['infer'] = multi_box_prior_infer_mxnet
|
||||
op['stop_attr_upd'] = True
|
||||
|
||||
|
@ -7,7 +7,6 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.TopKNormalizer import TopKNormalizer
|
||||
from extensions.middle.FakeSplitOutputs import AddFakeOutputsToSplit
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.ReduceOps import ReduceOp
|
||||
from extensions.ops.activation_ops import Activation
|
||||
@ -26,7 +25,6 @@ from mo.ops.convolution import Convolution
|
||||
from mo.ops.deconvolution import Deconvolution
|
||||
from mo.ops.op import Op
|
||||
from mo.ops.pooling import Pooling
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.class_registration import update_registration
|
||||
from mo.utils.import_extensions import import_by_path
|
||||
from mo.utils.ir_reader.extender import Extender
|
||||
@ -240,18 +238,6 @@ def ti_add_edge_attrs(op: Node):
|
||||
i += 1
|
||||
|
||||
|
||||
def assign_add_output_result(op: Node):
|
||||
"""
|
||||
Function adds necessary output result node for Assign node
|
||||
:param op:
|
||||
:return:
|
||||
"""
|
||||
assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
|
||||
''.format(op.soft_get('type'))
|
||||
tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
|
||||
op.out_port(0).connect(tmp_result.in_port(0))
|
||||
|
||||
|
||||
def copy_input_blobs(op: Node, copy_op: Node):
|
||||
"""
|
||||
Function copy input blob data nodes from restored graph to copied one
|
||||
@ -272,17 +258,12 @@ preprocessing_op_nodes = {
|
||||
'GroupConvolution': groupconv_to_conv,
|
||||
'ConvolutionBackpropData': backprop_to_deconv,
|
||||
'GroupConvolutionBackpropData': backprop_to_deconv,
|
||||
|
||||
}
|
||||
|
||||
# Map with postprocessing functions for nodes
|
||||
postprocessing_op_nodes = {
|
||||
'Assign': assign_add_output_result,
|
||||
'TensorIterator': ti_add_edge_attrs,
|
||||
'TopK': TopKNormalizer.normalize_outputs,
|
||||
# Call normalize Split outputs for generated IR by ir-reader
|
||||
'Split': AddFakeOutputsToSplit.split_normalize_outputs,
|
||||
'VariadicSplit': AddFakeOutputsToSplit.split_normalize_outputs,
|
||||
}
|
||||
|
||||
|
||||
@ -377,6 +358,10 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
|
||||
else:
|
||||
node = Op.get_op_class_by_name(op_type)(new_graph, op.attrs()).create_node()
|
||||
|
||||
# Fill out_ports_count attribute
|
||||
if 'out_ports_count' not in node and node.soft_get('type') != 'Result':
|
||||
node['out_ports_count'] = len(op.out_edges())
|
||||
|
||||
# This attribute is no longer needed and we can delete it
|
||||
if 'ir_data_attrs' in node:
|
||||
del node['ir_data_attrs']
|
||||
@ -398,6 +383,19 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
|
||||
|
||||
# Nodes postprocessing stage in new graph
|
||||
for op in new_graph.get_op_nodes():
|
||||
# Call normalize node outputs for restored operations to connect temporary Result operations for disconnected
|
||||
# output ports. We need to do that for correct shape inference. These Result operations will be removed during
|
||||
# IR emitting. For TopK operation outputs normalizing we should use specific
|
||||
# function TopKNormalizer.normalize_outputs.
|
||||
if op.soft_get('type') != 'TopK':
|
||||
Op.normalize_outputs(op)
|
||||
|
||||
# Set correct_data_type attribute to Const data nodes to correct processing of restored values
|
||||
if op.soft_get('type') == 'Const':
|
||||
assert len(op.out_nodes()) == 1 and op.out_node(0).soft_get('kind') == 'data',\
|
||||
'Const node {} not properly corrected to appropriate data node'.format(op.soft_get('name'))
|
||||
op.out_node(0)['correct_data_type'] = True
|
||||
|
||||
restore_tensor_names(op)
|
||||
|
||||
# operations postprocessing with some special types
|
||||
|
@ -1,14 +1,17 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.ConvolutionNormalizer import ConvolutionNormalizer, ConvolutionWithGroupsResolver
|
||||
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||
from extensions.back.PackBinaryWeights import PackBinaryWeights
|
||||
from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConstNodesReplacement
|
||||
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
|
||||
from extensions.back.blob_normalizer import BlobNormalizer
|
||||
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_precision
|
||||
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||
@ -17,7 +20,6 @@ from mo.utils.class_registration import apply_replacements_list
|
||||
from mo.utils.ir_engine.ir_engine import IREngine
|
||||
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
|
||||
from mo.utils.utils import get_mo_root_dir
|
||||
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||
|
||||
|
||||
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
|
||||
@ -54,8 +56,20 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
|
||||
if name is None:
|
||||
name = graph.name
|
||||
|
||||
precision = data_type_str_to_precision(graph.graph['cmd_params'].data_type)
|
||||
assert precision in ['FP16', 'FP32'], 'Cannot define precision for restored model!'
|
||||
if 'data_type' not in meta_data:
|
||||
log.debug('Provided `meta_data` does not contain `data_type` parameter. Set `data_type`'
|
||||
' parameter value to `FP32`.')
|
||||
# Set data_type to FP32. All restored constants will be saved in provided data type.
|
||||
data_type = 'FP32'
|
||||
|
||||
# We need to specify this attribute to pass graph transformations. This information will not be saved into IR.
|
||||
# All constants and placeholders will be saved with same types as restored from IR
|
||||
graph.graph['cmd_params'].data_type = data_type
|
||||
else:
|
||||
data_type = data_type_str_to_precision(graph.graph['cmd_params'].data_type)
|
||||
|
||||
assert data_type in ['FP16', 'FP32'], '`data_type` value {} is not supported by MO,' \
|
||||
' cannot save graph'.format(data_type)
|
||||
|
||||
# List items order matters, do not change it.
|
||||
transformation_list = [
|
||||
@ -64,7 +78,6 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
|
||||
PackBinaryWeights,
|
||||
BlobNormalizer,
|
||||
ConvolutionNormalizer,
|
||||
KaldiRemoveMemoryOutputBackReplacementPattern,
|
||||
MarkNodesWithShapeValues,
|
||||
]
|
||||
|
||||
@ -75,4 +88,4 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
|
||||
for_graph_and_each_sub_graph_recursively(graph, RemoveConstOps().find_and_replace_pattern)
|
||||
for_graph_and_each_sub_graph_recursively(graph, CreateConstNodesReplacement().find_and_replace_pattern)
|
||||
|
||||
prepare_emit_ir(graph, precision, path, name, meta_info=meta_data)
|
||||
prepare_emit_ir(graph, data_type, path, name, meta_info=meta_data)
|
||||
|
@ -92,7 +92,7 @@ def progress_bar(function: callable):
|
||||
assert arg in kwargs, msg.format(arg, 'is missing')
|
||||
assert kwargs[arg] is not None, msg.format(arg, 'should not be None')
|
||||
|
||||
if kwargs['graph'].graph['cmd_params'].progress:
|
||||
if 'progress' in kwargs['graph'].graph['cmd_params'] and kwargs['graph'].graph['cmd_params'].progress:
|
||||
bar_len = 20
|
||||
total_replacers_count = kwargs['num_transforms']
|
||||
|
||||
|
@ -2,11 +2,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from defusedxml.common import EntitiesForbidden
|
||||
|
||||
from mo.utils.ir_reader.restore_graph import restore_graph_from_ir
|
||||
from defusedxml.common import EntitiesForbidden
|
||||
|
||||
|
||||
class TestIRReader(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user