[MO] Support for common rt_info attribute in MO IR Reader (#9272)
* Support for common rt_info attribute in MO IR Reader
* Add missed change
* Moved back wrong change
* Change attr name
* Add support for rt_info for out ports
* Add emitting for rt_info
* Fix restoration error
* Add support for rt_info for input ports
* Add more comments
* Set correct layout attr to restored graph
This commit is contained in:
parent
3cef513495
commit
3e6951c1da
@ -8,11 +8,12 @@ from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.back.pass_separator import BackFinish
|
||||
from openvino.tools.mo.ops.tensor_iterator import TensorIterator
|
||||
from openvino.tools.mo.back.replacement import BackReplacementPattern
|
||||
from openvino.tools.mo.graph.graph import Graph
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
from openvino.tools.mo.ops.tensor_iterator import TensorIterator
|
||||
from openvino.tools.mo.utils.error import Error
|
||||
from openvino.tools.mo.utils.runtime_info import RTInfo
|
||||
from openvino.tools.mo.utils.utils import refer_to_faq_msg
|
||||
|
||||
|
||||
@ -86,6 +87,7 @@ class CreateConstNodesReplacement(BackReplacementPattern):
|
||||
'override_output_shape': node.has_valid('force_shape'),
|
||||
'force_type': node.soft_get('force_type', None),
|
||||
'correct_data_type': node.soft_get('correct_data_type', None),
|
||||
'rt_info': node.soft_get('rt_info', RTInfo()),
|
||||
}).create_node()
|
||||
const_node.add_input_port(0)
|
||||
graph.add_edges_from([(const_node_name, node.id, {'out': 0})])
|
||||
|
@ -140,6 +140,21 @@ def xml_ports(node: Node, element: Element, edges: Element):
|
||||
assert node.graph.node[u]['shape'] is not None, 'Input shape is not calculated properly for node {}'.format(
|
||||
node.id)
|
||||
xml_shape(node.graph.node[u]['shape'], port)
|
||||
|
||||
# support saving rt_info passed from IR Reader
|
||||
port_id = d['in']
|
||||
if node.has('restored_input_ports') and port_id in node.restored_input_ports:
|
||||
port_rt_info_value = node.restored_input_ports[port_id][2]
|
||||
if port_rt_info_value != {}:
|
||||
port_rt_info = SubElement(port, 'rt_info')
|
||||
for (name, version), info_elem in port_rt_info_value.items():
|
||||
attribute = SubElement(port_rt_info, 'attribute')
|
||||
attribute.set('name', name)
|
||||
attribute.set('version', str(version))
|
||||
params = info_elem.serialize(node) if not isinstance(info_elem, dict) else info_elem
|
||||
for key, value in params.items():
|
||||
attribute.set(key, value)
|
||||
|
||||
# u is a data node that has a single producer, let's find it
|
||||
assert (node.graph.node[u]['kind'] == 'data')
|
||||
in_nodes = list(node.graph.in_edges(u, data=True))
|
||||
@ -176,6 +191,18 @@ def xml_ports(node: Node, element: Element, edges: Element):
|
||||
port.set('names', ','.join(tensor_names))
|
||||
xml_shape(node.graph.node[v]['shape'], port)
|
||||
|
||||
# support saving rt_info passed from IR Reader
|
||||
if node.has('ports') and port_id in node.ports:
|
||||
port_rt_info_value = node.ports[port_id][2]
|
||||
if port_rt_info_value != []:
|
||||
port_rt_info = SubElement(port, 'rt_info')
|
||||
for (name, version), info_elem in port_rt_info_value.items():
|
||||
attribute = SubElement(port_rt_info, 'attribute')
|
||||
attribute.set('name', name)
|
||||
attribute.set('version', str(version))
|
||||
params = info_elem.serialize(node) if not isinstance(info_elem, dict) else info_elem
|
||||
for key, value in params.items():
|
||||
attribute.set(key, value)
|
||||
|
||||
def xml_consts(graph: Graph, node: Node, element: Element):
|
||||
blobs = None # sub-element that will be created on-demand
|
||||
@ -258,10 +285,7 @@ def serialize_runtime_info(node, parent_element: Element):
|
||||
attribute = SubElement(rt_info, 'attribute')
|
||||
attribute.set('name', name)
|
||||
attribute.set('version', str(version))
|
||||
params = info_elem.serialize(node)
|
||||
if len(params) == 0:
|
||||
rt_info.remove(attribute)
|
||||
continue
|
||||
params = info_elem.serialize(node) if not isinstance(info_elem, dict) else info_elem
|
||||
for key, value in params.items():
|
||||
attribute.set(key, value)
|
||||
if len(rt_info.attrib) == 0 and len(list(rt_info)) == 0:
|
||||
|
@ -144,6 +144,7 @@ def add_constant_operations(graph):
|
||||
if len(node.in_nodes()) == 0 and len(node.out_nodes()) != 0:
|
||||
# It's necessary to import here due to cycle dependencies
|
||||
from openvino.tools.mo.ops.const import Const
|
||||
from openvino.tools.mo.utils.runtime_info import RTInfo
|
||||
name = node.soft_get('name', node.id)
|
||||
new_name = re.sub(r'\/Output_\d+\/Data_(.?)+', '', name)
|
||||
const_node = Const(graph, dict(value=node.value, name=new_name,
|
||||
@ -151,6 +152,7 @@ def add_constant_operations(graph):
|
||||
override_output_shape=node.has_valid('force_shape'),
|
||||
force_type=node.soft_get('force_type', None),
|
||||
correct_data_type=node.soft_get('correct_data_type', False),
|
||||
rt_info=node.soft_get('rt_info', RTInfo()),
|
||||
)).create_node()
|
||||
graph.add_edges_from([(const_node.id, node.id, {'out': 0})])
|
||||
|
||||
|
@ -5,14 +5,13 @@ import hashlib
|
||||
import logging as log
|
||||
import os
|
||||
import sys
|
||||
|
||||
from defusedxml import defuse_stdlib
|
||||
import defusedxml.ElementTree as ET
|
||||
from argparse import Namespace
|
||||
from collections import namedtuple, defaultdict
|
||||
from collections import namedtuple, defaultdict, OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import defusedxml.ElementTree as ET
|
||||
import numpy as np
|
||||
from defusedxml import defuse_stdlib
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import dynamic_dimension_value, shape_array
|
||||
from openvino.tools.mo.graph.graph import Node, Graph
|
||||
@ -59,7 +58,9 @@ class IREngine(object):
|
||||
self.graph.graph['hashes'] = {}
|
||||
|
||||
self.graph.graph['ir_version'] = int(xml_root.attrib['version']) if xml_root.attrib.get('version') is not None else None
|
||||
self.graph.graph['layout'] = 'NCHW'
|
||||
self.graph.graph['layout'] = 'NCHW' # We set layout to NCHW as default value and
|
||||
# changing it in __rt_info_check_layout when necessary
|
||||
|
||||
self.graph.name = xml_root.attrib['name'] if xml_root.attrib.get('name') is not None else None
|
||||
|
||||
# Parse XML
|
||||
@ -204,7 +205,7 @@ class IREngine(object):
|
||||
layer_id = layer.attrib['id']
|
||||
|
||||
layer_attrs = layer.attrib
|
||||
layer_attrs.update({'ports': {}, 'kind': 'op'})
|
||||
layer_attrs.update({'ports': {}, 'restored_input_ports': {}, 'kind': 'op'})
|
||||
|
||||
inputs_counter = 0
|
||||
|
||||
@ -224,22 +225,50 @@ class IREngine(object):
|
||||
layer_attrs.update(new_attrs)
|
||||
elif attr.tag == 'input':
|
||||
inputs_counter = len(attr)
|
||||
|
||||
input = attr
|
||||
for port in input:
|
||||
port_id = int(port.attrib['id'])
|
||||
input_shape = []
|
||||
port_rt_info = {}
|
||||
for dim in port:
|
||||
if dim.tag == "dim":
|
||||
input_shape.append(int(dim.text))
|
||||
if dim.tag == 'rt_info':
|
||||
for attr in dim:
|
||||
port_rt_info.update(self.__read_rt_info_common(attr))
|
||||
self.__rt_info_check_layout(attr)
|
||||
|
||||
input_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in input_shape])
|
||||
|
||||
in_tensor_names = None
|
||||
if 'names' in port.attrib:
|
||||
in_tensor_names = port.attrib['names']
|
||||
|
||||
# special attribute to pass information about operation input ports
|
||||
layer_attrs['restored_input_ports'].update({port_id: (input_shape, in_tensor_names, port_rt_info)})
|
||||
elif attr.tag == 'output':
|
||||
output = attr
|
||||
for port in output:
|
||||
port_id = int(port.attrib['id'])
|
||||
output_shape = []
|
||||
port_rt_info = {}
|
||||
for dim in port:
|
||||
if dim.tag == "dim":
|
||||
output_shape.append(int(dim.text))
|
||||
if dim.tag == 'rt_info':
|
||||
for attr in dim:
|
||||
port_rt_info.update(self.__read_rt_info_common(attr))
|
||||
self.__rt_info_check_layout(attr)
|
||||
|
||||
output_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in output_shape])
|
||||
|
||||
out_tensor_names = None
|
||||
if 'names' in port.attrib:
|
||||
out_tensor_names = port.attrib['names']
|
||||
|
||||
layer_attrs['ports'].update({port_id: (output_shape, out_tensor_names)})
|
||||
# special attribute to pass information about operation output ports
|
||||
# NOTE: renaming or structure changing of this attribute may have big impact on tests
|
||||
layer_attrs['ports'].update({port_id: (output_shape, out_tensor_names, port_rt_info)})
|
||||
elif attr.tag == 'blobs':
|
||||
in_port = inputs_counter
|
||||
for blob_attr in attr:
|
||||
@ -460,8 +489,10 @@ class IREngine(object):
|
||||
attr_name = attr.attrib['name']
|
||||
if attr_name == 'old_api_map_order':
|
||||
rt_info.info.update(self.__read_old_api_map_order(attr, layer.attrib['type']))
|
||||
if attr_name == 'old_api_map_element_type':
|
||||
elif attr_name == 'old_api_map_element_type':
|
||||
rt_info.info.update(self.__read_old_api_map_element_type(attr, layer.attrib['type']))
|
||||
else:
|
||||
rt_info.info.update((self.__read_rt_info_common(attr)))
|
||||
|
||||
layer_attrs.update({'rt_info': rt_info})
|
||||
return layer_attrs
|
||||
@ -487,3 +518,21 @@ class IREngine(object):
|
||||
old_api_map = OldAPIMapElementType(version=version)
|
||||
old_api_map.set_legacy_type(element_type)
|
||||
return {('old_api_map_element_type', version): old_api_map}
|
||||
|
||||
@staticmethod
def __read_rt_info_common(attr):
    """Parse a generic rt_info ``<attribute>`` XML element.

    Returns a single-entry dict keyed by ``(attribute_name, version)`` whose
    value is an OrderedDict of every remaining XML attribute, so arbitrary
    rt_info payloads survive a save/restore round trip untouched.
    """
    name = attr.attrib['name']
    version = int(attr.attrib['version'])
    # Everything except the bookkeeping keys is part of the payload.
    payload = OrderedDict(
        (key, value) for key, value in attr.attrib.items()
        if key not in ('name', 'version')
    )
    return {(name, version): payload}
|
||||
|
||||
def __rt_info_check_layout(self, attr):
    """Propagate a layout found in an rt_info attribute to the whole graph.

    The serialized value looks like ``'[N,C,H,W]'``; commas, brackets and
    surrounding spaces are stripped so e.g. ``'NCHW'`` is stored in
    ``self.graph.graph['layout']``. When the attribute carries no 'layout'
    entry the graph keeps its current (default) layout.
    """
    # Direct lookup instead of scanning every key; dead commented-out
    # strip-chain removed.
    raw_layout = attr.attrib.get('layout')
    if raw_layout is not None:
        # '[N,C,H,W]' -> 'NCHW'
        self.graph.graph['layout'] = raw_layout.replace(',', '').strip('[] ')
|
||||
|
@ -278,8 +278,8 @@ def restore_tensor_names(op: Node):
|
||||
for out_port in op.ports:
|
||||
# op.ports is our internal attribute, dictionary, where keys are numbers of output ports
|
||||
# and values are tuples with shape and tensor name:
|
||||
# {out_port_idx_1: (out_port_idx_1_shape, out_port_idx_1_tensor_name),
|
||||
# out_port_idx_2: (out_port_idx_2_shape, out_port_idx_2_tensor_name)}
|
||||
# {out_port_idx_1: (out_port_idx_1_shape, out_port_idx_1_tensor_name, out_port_idx_1_rt_info),
|
||||
# out_port_idx_2: (out_port_idx_2_shape, out_port_idx_2_tensor_name, out_port_idx_2_rt_info)}
|
||||
out_tensor_names = op.ports[out_port][1]
|
||||
|
||||
# handle Constant operations with old style output port numbering
|
||||
@ -405,6 +405,9 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
|
||||
'Const node {} not properly corrected to appropriate data node'.format(op.soft_get('name'))
|
||||
op.out_node(0)['correct_data_type'] = True
|
||||
|
||||
if op.has_and_set('rt_info'):
|
||||
op.out_node(0)['rt_info'] = op.rt_info
|
||||
|
||||
restore_tensor_names(op)
|
||||
|
||||
# operations postprocessing with some special types
|
||||
|
Loading…
Reference in New Issue
Block a user