Remove incorrect layout changing (#9764)

This commit is contained in:
Anton Chetverikov 2022-01-25 14:10:06 +03:00 committed by GitHub
parent 9a522137bf
commit 88903ee7ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,8 +58,12 @@ class IREngine(object):
self.graph.graph['hashes'] = {} self.graph.graph['hashes'] = {}
self.graph.graph['ir_version'] = int(xml_root.attrib['version']) if xml_root.attrib.get('version') is not None else None self.graph.graph['ir_version'] = int(xml_root.attrib['version']) if xml_root.attrib.get('version') is not None else None
self.graph.graph['layout'] = 'NCHW' # We set layout to NCHW as default value and
# changing it in __rt_info_check_layout if it will be necessary # NOTE: THis is MO internal attribute, it cannot be used for
# defining graph input layout. We set it to NCHW as in MO back stage
# during conversion for correct shape inference of layout specific
# operations (ExtractImagePatches, SpaceToDepth, etc.)
self.graph.graph['layout'] = 'NCHW'
self.graph.name = xml_root.attrib['name'] if xml_root.attrib.get('name') is not None else None self.graph.name = xml_root.attrib['name'] if xml_root.attrib.get('name') is not None else None
@ -237,7 +241,6 @@ class IREngine(object):
if dim.tag == 'rt_info': if dim.tag == 'rt_info':
for attr in dim: for attr in dim:
port_rt_info.update(self.__read_rt_info_common(attr)) port_rt_info.update(self.__read_rt_info_common(attr))
self.__rt_info_check_layout(attr)
input_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in input_shape]) input_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in input_shape])
@ -259,7 +262,6 @@ class IREngine(object):
if dim.tag == 'rt_info': if dim.tag == 'rt_info':
for attr in dim: for attr in dim:
port_rt_info.update(self.__read_rt_info_common(attr)) port_rt_info.update(self.__read_rt_info_common(attr))
self.__rt_info_check_layout(attr)
output_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in output_shape]) output_shape = shape_array([d if d != -1 else dynamic_dimension_value for d in output_shape])
@ -528,11 +530,3 @@ class IREngine(object):
if key not in ('name', 'version'): if key not in ('name', 'version'):
rt_info[key] = attr.attrib[key] rt_info[key] = attr.attrib[key]
return {(attr_name, version): rt_info} return {(attr_name, version): rt_info}
def __rt_info_check_layout(self, attr):
graph_layout = None
for key in attr.attrib:
if key == 'layout':
graph_layout = attr.attrib[key].replace(',', '').strip('[] ')# .strip(']').strip(',').strip(' ')
if graph_layout is not None:
self.graph.graph['layout'] = graph_layout