From 2c4afd8cd40c71eee5c832707a517f567a469f4c Mon Sep 17 00:00:00 2001 From: Anton Chetverikov Date: Thu, 18 Feb 2021 10:38:04 +0300 Subject: [PATCH] Tensor names support in MO IR Reader (#4194) * Added attributes save modes * Added tensor names to IR * Reformat code * Add support for tensor names in MO IR Reader * Unit tests and code refactoring * Fixed error * Code refactoring * Code refactoring * Code refactoring * Error fixed * Error fixed * Bug fixed * Bug fixed * Additional unit tests and comments * Small update * Update fake infer function * Update names restoring * optimize imports * Add support for old-style constants and for commas in reader * Added dest mode in Fuse Mul * Update default values * Fix missed debug info in some specific cases * Fix a lot of issues with missedand wrong names provoding * Resolve review comments * Update test IR's * Refactor and simplify code * More simplification * Remove unneccessary changes * model-optimizer/mo/utils/ir_reader/layer_to_class_test.py * Add separate tests for names restoring * Update copyright year * Apply review comments Co-authored-by: Anastasia Popova --- .../mo/utils/ir_engine/ir_engine.py | 10 +++-- .../mo/utils/ir_reader/extender.py | 4 +- .../mo/utils/ir_reader/layer_to_class.py | 34 +++++++++++++++- .../mo/utils/ir_reader/layer_to_class_test.py | 40 ++++++++++++++++++- 4 files changed, 80 insertions(+), 8 deletions(-) diff --git a/model-optimizer/mo/utils/ir_engine/ir_engine.py b/model-optimizer/mo/utils/ir_engine/ir_engine.py index 5424b7bcaf9..dd63b5f199d 100644 --- a/model-optimizer/mo/utils/ir_engine/ir_engine.py +++ b/model-optimizer/mo/utils/ir_engine/ir_engine.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -123,7 +123,7 @@ class IREngine(object): data_nodes = {} for port in self.graph.node[node]['ports']: data = self.graph.unique_id(prefix='data_') - self.graph.add_node(data, **{'kind': 'data', 'shape': self.graph.node[node]['ports'][port], + self.graph.add_node(data, **{'kind': 'data', 'shape': self.graph.node[node]['ports'][port][0], 'value': None}) self.graph.add_edges_from([(node, data, {'out': port})]) data_nodes.update({port: data}) @@ -232,7 +232,11 @@ class IREngine(object): for dim in port: output_shape.append(int(dim.text)) - layer_attrs['ports'].update({port_id: output_shape}) + out_tensor_names = None + if 'names' in port.attrib: + out_tensor_names = port.attrib['names'] + + layer_attrs['ports'].update({port_id: (output_shape, out_tensor_names)}) elif attr.tag == 'blobs': in_port = inputs_counter for blob_attr in attr: diff --git a/model-optimizer/mo/utils/ir_reader/extender.py b/model-optimizer/mo/utils/ir_reader/extender.py index fd6604ac018..2bc2998458e 100644 --- a/model-optimizer/mo/utils/ir_reader/extender.py +++ b/model-optimizer/mo/utils/ir_reader/extender.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -51,5 +51,5 @@ class Extender(object): def const_shape_infer(node: Node): i = len(node.in_nodes()) for num in node.out_nodes(): - node.out_node(num).shape = int64_array(node.ports[i]) + node.out_node(num).shape = int64_array(node.ports[i][0]) i += 1 diff --git a/model-optimizer/mo/utils/ir_reader/layer_to_class.py b/model-optimizer/mo/utils/ir_reader/layer_to_class.py index b027f5fab37..37b6a42672e 100644 --- a/model-optimizer/mo/utils/ir_reader/layer_to_class.py +++ b/model-optimizer/mo/utils/ir_reader/layer_to_class.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -281,6 +281,35 @@ postprocessing_op_nodes = { } +def restore_tensor_names(op: Node): + for out_port in op.ports: + # op.ports is our internal attribute, dictionary, where keys are numbers of output ports + # and values are tuples with shape and tensor name: + # {out_port_idx_1: (out_port_idx_1_shape, out_port_idx_1_tensor_name), + # out_port_idx_2: (out_port_idx_2_shape, out_port_idx_2_tensor_name)} + out_tensor_names = op.ports[out_port][1] + + # handle Constant operations with old style output port numbering + if op.soft_get('type') == 'Const': + assert len(op.ports) == 1, 'Something wrong with Constant node: {}, wrong number ' \ + 'of output ports: {}!'.format(op.soft_get('name'), len(op.ports)) + out_port = 0 + + out_port = out_port - len(op.in_nodes()) + + if out_tensor_names is not None: + # handle tensor names with commas and add them to dictionary as separate items + if out_tensor_names.find(',') >= 0: + str_to_replace = '' + out_tensor_names = (out_tensor_names.replace('\\,', str_to_replace)).split(',') + op.out_node(out_port)['fw_tensor_debug_info'] = [] + for out_tensor_name in out_tensor_names: + out_tensor_name = out_tensor_name.replace(str_to_replace, ',') + op.out_node(out_port)['fw_tensor_debug_info'].append((out_tensor_name, out_port, out_tensor_name)) + else: + op.out_node(out_port)['fw_tensor_debug_info'] = [(out_tensor_names, out_port, out_tensor_names)] + + def copy_graph_with_ops(graph: Graph) -> Graph: """ Function to copy graph and apply extenders to appropriate nodes @@ -342,6 +371,9 @@ def copy_graph_with_ops(graph: Graph) -> Graph: # Nodes postprocessing stage in new graph for op in new_graph.get_op_nodes(): + restore_tensor_names(op) + + # operations postprocessing with some special types if op.soft_get('type') in postprocessing_op_nodes: postprocessing_op_nodes[op.type](op) diff --git a/model-optimizer/mo/utils/ir_reader/layer_to_class_test.py b/model-optimizer/mo/utils/ir_reader/layer_to_class_test.py index 06bcec1c596..70e6fbc1de7 100644 --- a/model-optimizer/mo/utils/ir_reader/layer_to_class_test.py +++ b/model-optimizer/mo/utils/ir_reader/layer_to_class_test.py @@ -1,5 +1,5 @@ """ - Copyright (c) 2018-2020 Intel Corporation + Copyright (c) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ from generator import generator, generate from mo.graph.graph import Node from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.ir_reader.layer_to_class import groupconv_to_conv +from mo.utils.ir_reader.layer_to_class import groupconv_to_conv, restore_tensor_names from mo.utils.unittest.graph import build_graph @@ -107,3 +107,39 @@ class TestFunction(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) self.assertTrue(flag, resp) + + def test_restore_tensor_names(self): + + shape = [1, 3, 224, 224] + + nodes_attributes = { + 'input': {'kind': 'op', 'type': 'Parameter', 'ports': {0: (shape, 'abc,def')}}, + 'input_data': {'shape': shape, 'kind': 'data'}, + 'add': {'kind': 'op', 'type': 'Add', 'ports': {2: (shape, 'ghi\,jkl')}}, + 'add_data': {'shape': shape, 'kind': 'data'}, + 'add_const': {'kind': 'op', 'type': 'Const', 'ports': {0: (shape, 'mno,pqr\,stu')}}, + 'add_const_data': {'shape': shape, 'kind': 'data'}, + 'result': {'kind': 'op', 'type': 'Result', 'ports': {0: (shape, None)}} + } + + edges = [('input', 'input_data'), + ('input_data', 'add'), + ('add_const', 'add_const_data'), + ('add_const_data', 'add'), + ('add', 'add_data'), + ('add_data', 'result'), + ] + + graph = build_graph(nodes_attributes, edges, nodes_with_edges_only=True) + + for op in graph.get_op_nodes(): + restore_tensor_names(op) + + node_1 = Node(graph, 'input_data') + node_2 = Node(graph, 'add_data') + node_3 = Node(graph, 'add_const_data') + + assert node_1['fw_tensor_debug_info'] == [('abc', 0, 'abc'), ('def', 0, 'def')], 'Restored debug info is wrong!' + assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 0, 'ghi,jkl')], 'Restored debug info is wrong!' + assert node_3['fw_tensor_debug_info'] == [('mno', 0, 'mno'), ('pqr,stu', 0, 'pqr,stu')],\ + 'Restored debug info is wrong!'