[MO] Support for common rt_info attribute in MO IR Reader (#9272)

* Support for common rt_info attribute in MO IR Reader

* Add missed change

* Moved back wrong change

* Change attr name

* Add support for rt_info for out ports

* Add emitting for rt_info

* Fix restoration error

* Add support for rt_info for input ports

* Add more comments

* Set correct layout attr to restored graph
This commit is contained in:
Anton Chetverikov 2021-12-29 00:59:48 +03:00 committed by GitHub
parent 3cef513495
commit 3e6951c1da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 96 additions and 16 deletions

View File

@ -8,11 +8,12 @@ from collections import defaultdict
import numpy as np
from openvino.tools.mo.back.pass_separator import BackFinish
from openvino.tools.mo.ops.tensor_iterator import TensorIterator
from openvino.tools.mo.back.replacement import BackReplacementPattern
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.ops.const import Const
from openvino.tools.mo.ops.tensor_iterator import TensorIterator
from openvino.tools.mo.utils.error import Error
from openvino.tools.mo.utils.runtime_info import RTInfo
from openvino.tools.mo.utils.utils import refer_to_faq_msg
@ -86,6 +87,7 @@ class CreateConstNodesReplacement(BackReplacementPattern):
'override_output_shape': node.has_valid('force_shape'),
'force_type': node.soft_get('force_type', None),
'correct_data_type': node.soft_get('correct_data_type', None),
'rt_info': node.soft_get('rt_info', RTInfo()),
}).create_node()
const_node.add_input_port(0)
graph.add_edges_from([(const_node_name, node.id, {'out': 0})])

View File

@ -140,6 +140,21 @@ def xml_ports(node: Node, element: Element, edges: Element):
assert node.graph.node[u]['shape'] is not None, 'Input shape is not calculated properly for node {}'.format(
node.id)
xml_shape(node.graph.node[u]['shape'], port)
# support saving rt_info passed from IR Reader
port_id = d['in']
if node.has('restored_input_ports') and port_id in node.restored_input_ports:
port_rt_info_value = node.restored_input_ports[port_id][2]
if port_rt_info_value != {}:
port_rt_info = SubElement(port, 'rt_info')
for (name, version), info_elem in port_rt_info_value.items():
attribute = SubElement(port_rt_info, 'attribute')
attribute.set('name', name)
attribute.set('version', str(version))
params = info_elem.serialize(node) if not isinstance(info_elem, dict) else info_elem
for key, value in params.items():
attribute.set(key, value)
# u is a data node that has a single producer, let's find it
assert (node.graph.node[u]['kind'] == 'data')
in_nodes = list(node.graph.in_edges(u, data=True))
@ -176,6 +191,18 @@ def xml_ports(node: Node, element: Element, edges: Element):
port.set('names', ','.join(tensor_names))
xml_shape(node.graph.node[v]['shape'], port)
# support saving rt_info passed from IR Reader
if node.has('ports') and port_id in node.ports:
port_rt_info_value = node.ports[port_id][2]
if port_rt_info_value != []:
port_rt_info = SubElement(port, 'rt_info')
for (name, version), info_elem in port_rt_info_value.items():
attribute = SubElement(port_rt_info, 'attribute')
attribute.set('name', name)
attribute.set('version', str(version))
params = info_elem.serialize(node) if not isinstance(info_elem, dict) else info_elem
for key, value in params.items():
attribute.set(key, value)
def xml_consts(graph: Graph, node: Node, element: Element):
blobs = None # sub-element that will be created on-demand
@ -258,10 +285,7 @@ def serialize_runtime_info(node, parent_element: Element):
attribute = SubElement(rt_info, 'attribute')
attribute.set('name', name)
attribute.set('version', str(version))
params = info_elem.serialize(node)
if len(params) == 0:
rt_info.remove(attribute)
continue
params = info_elem.serialize(node) if not isinstance(info_elem, dict) else info_elem
for key, value in params.items():
attribute.set(key, value)
if len(rt_info.attrib) == 0 and len(list(rt_info)) == 0:

View File

@ -144,6 +144,7 @@ def add_constant_operations(graph):
if len(node.in_nodes()) == 0 and len(node.out_nodes()) != 0:
# It's necessary to import here due to cycle dependencies
from openvino.tools.mo.ops.const import Const
from openvino.tools.mo.utils.runtime_info import RTInfo
name = node.soft_get('name', node.id)
new_name = re.sub(r'\/Output_\d+\/Data_(.?)+', '', name)
const_node = Const(graph, dict(value=node.value, name=new_name,
@ -151,6 +152,7 @@ def add_constant_operations(graph):
override_output_shape=node.has_valid('force_shape'),
force_type=node.soft_get('force_type', None),
correct_data_type=node.soft_get('correct_data_type', False),
rt_info=node.soft_get('rt_info', RTInfo()),
)).create_node()
graph.add_edges_from([(const_node.id, node.id, {'out': 0})])

View File

@ -5,14 +5,13 @@ import hashlib
import logging as log
import os
import sys
from defusedxml import defuse_stdlib
import defusedxml.ElementTree as ET
from argparse import Namespace
from collections import namedtuple, defaultdict
from collections import namedtuple, defaultdict, OrderedDict
from pathlib import Path
import defusedxml.ElementTree as ET
import numpy as np
from defusedxml import defuse_stdlib
from openvino.tools.mo.front.common.partial_infer.utils import dynamic_dimension_value, shape_array
from openvino.tools.mo.graph.graph import Node, Graph
@ -59,7 +58,9 @@ class IREngine(object):
self.graph.graph['hashes'] = {}
self.graph.graph['ir_version'] = int(xml_root.attrib['version']) if xml_root.attrib.get('version') is not None else None
self.graph.graph['layout'] = 'NCHW'
self.graph.graph['layout'] = 'NCHW' # We set layout to NCHW as default value and
# changing it in __rt_info_check_layout if it will be necessary
self.graph.name = xml_root.attrib['name'] if xml_root.attrib.get('name') is not None else None
# Parse XML
@ -204,7 +205,7 @@ class IREngine(object):
layer_id = layer.attrib['id']
layer_attrs = layer.attrib
layer_attrs.update({'ports': {}, 'kind': 'op'})
layer_attrs.update({'ports': {}, 'restored_input_ports': {}, 'kind': 'op'})
inputs_counter = 0
@ -224,22 +225,50 @@ class IREngine(object):
layer_attrs.update(new_attrs)
elif attr.tag == 'input':
inputs_counter = len(attr)
input = attr
for port in input:
port_id = int(port.attrib['id'])
input_shape = []
port_rt_info = {}
for dim in port:
if dim.tag == "dim":
input_shape.append(int(dim.text))
if dim.tag == 'rt_info':
for attr in dim:
port_rt_info.update(self.__read_rt_info_common(attr))
self.__rt_info_check_layout(attr)
input_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in input_shape])
in_tensor_names = None
if 'names' in port.attrib:
in_tensor_names = port.attrib['names']
# special attribute to pass information about operation input ports
layer_attrs['restored_input_ports'].update({port_id: (input_shape, in_tensor_names, port_rt_info)})
elif attr.tag == 'output':
output = attr
for port in output:
port_id = int(port.attrib['id'])
output_shape = []
port_rt_info = {}
for dim in port:
if dim.tag == "dim":
output_shape.append(int(dim.text))
if dim.tag == 'rt_info':
for attr in dim:
port_rt_info.update(self.__read_rt_info_common(attr))
self.__rt_info_check_layout(attr)
output_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in output_shape])
out_tensor_names = None
if 'names' in port.attrib:
out_tensor_names = port.attrib['names']
layer_attrs['ports'].update({port_id: (output_shape, out_tensor_names)})
# special attribute to pass information about operation input ports
# NOTE: renaming or structure changing of this attribute may have big impact on tests
layer_attrs['ports'].update({port_id: (output_shape, out_tensor_names, port_rt_info)})
elif attr.tag == 'blobs':
in_port = inputs_counter
for blob_attr in attr:
@ -460,8 +489,10 @@ class IREngine(object):
attr_name = attr.attrib['name']
if attr_name == 'old_api_map_order':
rt_info.info.update(self.__read_old_api_map_order(attr, layer.attrib['type']))
if attr_name == 'old_api_map_element_type':
elif attr_name == 'old_api_map_element_type':
rt_info.info.update(self.__read_old_api_map_element_type(attr, layer.attrib['type']))
else:
rt_info.info.update((self.__read_rt_info_common(attr)))
layer_attrs.update({'rt_info': rt_info})
return layer_attrs
@ -487,3 +518,21 @@ class IREngine(object):
old_api_map = OldAPIMapElementType(version=version)
old_api_map.set_legacy_type(element_type)
return {('old_api_map_element_type', version): old_api_map}
@staticmethod
def __read_rt_info_common(attr):
attr_name = attr.attrib['name']
version = int(attr.attrib['version'])
rt_info = OrderedDict()
for key in attr.attrib:
if key not in ('name', 'version'):
rt_info[key] = attr.attrib[key]
return {(attr_name, version): rt_info}
def __rt_info_check_layout(self, attr):
graph_layout = None
for key in attr.attrib:
if key == 'layout':
graph_layout = attr.attrib[key].replace(',', '').strip('[] ')# .strip(']').strip(',').strip(' ')
if graph_layout is not None:
self.graph.graph['layout'] = graph_layout

View File

@ -278,8 +278,8 @@ def restore_tensor_names(op: Node):
for out_port in op.ports:
# op.ports is our internal attribute, dictionary, where keys are numbers of output ports
# and values are tuples with shape and tensor name:
# {out_port_idx_1: (out_port_idx_1_shape, out_port_idx_1_tensor_name),
# out_port_idx_2: (out_port_idx_2_shape, out_port_idx_2_tensor_name)}
# {out_port_idx_1: (out_port_idx_1_shape, out_port_idx_1_tensor_name, out_port_idx_1_rt_info),
# out_port_idx_2: (out_port_idx_2_shape, out_port_idx_2_tensor_name, out_port_idx_2_rt_info)}
out_tensor_names = op.ports[out_port][1]
# handle Constant operations with old style output port numbering
@ -405,6 +405,9 @@ def copy_graph_with_ops(graph: Graph) -> Graph:
'Const node {} not properly corrected to appropriate data node'.format(op.soft_get('name'))
op.out_node(0)['correct_data_type'] = True
if op.has_and_set('rt_info'):
op.out_node(0)['rt_info'] = op.rt_info
restore_tensor_names(op)
# operations postprocessing with some special types