Changed "out_port_id" attribute setting in mapping file to store tensor names. (#5344)

* Removed port id from fw_tensor_debug_info attribute.

* Added port number to tensor names in kaldi, mxnet. Fixed Const naming.

* Sort imports.
This commit is contained in:
Anastasia Popova 2021-04-29 14:05:35 +03:00 committed by GitHub
parent b1a4a73328
commit 07214d0a47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 69 additions and 67 deletions

View File

@ -2,12 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
import logging as log
import re
from collections import defaultdict
import numpy as np
from extensions.back.pass_separator import BackFinish
from extensions.ops.tensor_iterator import TensorIterator, get_internal_node_by_layer_id
from extensions.ops.tensor_iterator import TensorIterator
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
from mo.ops.const import Const
@ -77,7 +78,8 @@ class CreateConstNodesReplacement(BackReplacementPattern):
if self._check_bin_attrs(node):
if node.has_valid('value'):
const_node_name = graph.unique_id(node.id + '_const')
const_node_name = node.soft_get('name', node.id)
const_node_name = graph.unique_id(re.sub(r'\/Output_\d+\/Data_(.?)+', '', const_node_name))
log.debug("Added Const node '{}'".format(const_node_name))
const_node = Const(graph, {'name': const_node_name, 'value': node.value,
'force_shape': node.soft_get('force_shape', None),

View File

@ -81,7 +81,7 @@ class LoopExtractor(FrontExtractorOp):
'out': src_port,
'in': dst_port,
'name': inp,
'fw_tensor_debug_info': [(src_id, dst_port, inp)],
'fw_tensor_debug_info': [(src_id, inp)],
'in_attrs': ['in', 'name'],
'out_attrs': ['out', 'name'],
'data_attrs': ['fw_tensor_debug_info']
@ -136,7 +136,7 @@ class LoopExtractor(FrontExtractorOp):
main_graph.add_edge(src_node, loop_node.id, **{'out': src_port,
'in': next_loop_input_port_idx,
'name': src_node,
'fw_tensor_debug_info': [(src_node, next_loop_input_port_idx, tensor_name)],
'fw_tensor_debug_info': [(src_node, tensor_name)],
'in_attrs': ['in', 'name'],
'out_attrs': ['out', 'name'],
'data_attrs': ['fw_tensor_debug_info']}

View File

@ -319,8 +319,8 @@ def add_edge_caffe(graph: Graph, bottom: str, dst_layer: str, blob_producers: di
'out': src_port,
'in': dst_port,
'name': bottom,
# debug anchor for a framework name, out port and tensor name
'fw_tensor_debug_info': [(blob_producers[bottom][2], src_port, bottom)],
# debug anchor for a framework name and tensor name
'fw_tensor_debug_info': [(blob_producers[bottom][2], bottom)],
'in_attrs': ['in', 'name'],
'out_attrs': ['out', 'name'],
'data_attrs': ['fw_tensor_debug_info']

View File

@ -860,8 +860,9 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
input_op = Parameter(graph, dict(shape=shape, data_type=data_type, initial_node_name=node_id,
name=get_new_placeholder_name(node_id, is_out_port, port)))
fw_name = Node(graph, node_id).soft_get('name')
edge_attrs = {'in': port, 'out': 0, 'in_attrs': ['in'], 'out_attrs': ['out'],
'fw_tensor_debug_info': [(Node(graph, node_id).soft_get('name'), port)],
'fw_tensor_debug_info': [(fw_name, fw_name)],
'data_attrs': ['fw_tensor_debug_info']}
if not data:
if is_out_port:

View File

@ -332,7 +332,7 @@ def create_edge_attrs(prev_layer_id: str, next_layer_id: str, tensor_name: str,
'out': out_port,
'in': in_port,
'name': next_layer_id,
'fw_tensor_debug_info': [(prev_layer_id, out_port, tensor_name)],
'fw_tensor_debug_info': [(prev_layer_id, tensor_name + ":" + str(out_port))],
'in_attrs': ['in', 'permutation'],
'out_attrs': ['out', 'permutation'],
'data_attrs': ['fw_tensor_debug_info']

View File

@ -114,8 +114,8 @@ def create_mxnet_edge(src_node_id: str, dst_node_id: str, src_port: int, dst_por
edge_attrs = {
'in': src_port,
'out': dst_port,
# debug anchor for framework name, out port and tensor name
'fw_tensor_debug_info': [(framework_name, dst_port, framework_name)],
# debug anchor for framework name and tensor name
'fw_tensor_debug_info': [(framework_name, framework_name + ":" + str(dst_port))],
'in_attrs': ['in'],
'out_attrs': ['out'],
'data_attrs': ['fw_tensor_debug_info']

View File

@ -96,7 +96,7 @@ def protobuf2nx(graph: Graph, pb):
'out': src_port,
'in': dst_port,
'name': inp,
'fw_tensor_debug_info': [(src_id, src_port, inp)],
'fw_tensor_debug_info': [(src_id, inp)],
'in_attrs': ['in', 'name'],
'out_attrs': ['out', 'name'],
'data_attrs': ['fw_tensor_debug_info']
@ -110,7 +110,7 @@ def protobuf2nx(graph: Graph, pb):
'out': src_port,
'in': 0,
'name': out,
'fw_tensor_debug_info': [(fw_name, src_port, out)],
'fw_tensor_debug_info': [(fw_name, out)],
'in_attrs': ['in', 'name'],
'out_attrs': ['out', 'name'],
'data_attrs': ['fw_tensor_debug_info']

View File

@ -38,7 +38,7 @@ def create_tf_edge(src_node_id: str, dst_node_id: str, in_port: int):
'in': in_port,
'out': src_port,
# debug anchor for a framework name, out port and tensor name
'fw_tensor_debug_info': [(src_node_id, src_port, tensor_name)],
'fw_tensor_debug_info': [(src_node_id, tensor_name)],
'in_attrs': ['in', 'control_flow_edge', 'permutation'],
'out_attrs': ['out', 'permutation'],
'data_attrs': ['fw_tensor_debug_info'],

View File

@ -269,8 +269,8 @@ class Port:
if attrs['fw_tensor_debug_info'] is None:
return tensor_names_list
for attr in attrs['fw_tensor_debug_info']:
if attr is not None and len(attr) >= 3:
tensor_name = attr[2]
if attr is not None and len(attr) >= 2:
tensor_name = attr[1]
if tensor_name is not None and len(tensor_name) > 0:
tensor_names_list.append(tensor_name.replace(',', '\\,'))
return tensor_names_list

View File

@ -302,9 +302,9 @@ def restore_tensor_names(op: Node):
op.out_node(out_port)['fw_tensor_debug_info'] = []
for out_tensor_name in out_tensor_names:
out_tensor_name = out_tensor_name.replace(str_to_replace, ',')
op.out_node(out_port)['fw_tensor_debug_info'].append((out_tensor_name, out_port, out_tensor_name))
op.out_node(out_port)['fw_tensor_debug_info'].append((out_tensor_name, out_tensor_name))
else:
op.out_node(out_port)['fw_tensor_debug_info'] = [(out_tensor_names, out_port, out_tensor_names)]
op.out_node(out_port)['fw_tensor_debug_info'] = [(out_tensor_names, out_tensor_names)]
def copy_graph_with_ops(graph: Graph) -> Graph:

View File

@ -13,8 +13,8 @@ nodes = {
**regular_op('Op2', {'type': 'Op2', 'kind': 'op', 'op': 'Op2'}),
**result('result1'),
**result('result2'),
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 0, 'Op1_tensor')]},
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 0, 'Op2_tensor')]},
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1_tensor')]},
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op2_tensor')]},
}

View File

@ -22,7 +22,7 @@ class TestsOutputCut(unittest.TestCase):
def test_case1(self):
graph = build_graph(nodes, [('Parameter1', 'FakeOutput1',
{'in': 0, 'out': 0, 'fw_tensor_debug_info':
[('Parameter1', 0, 'Parameter1_tensor_name')]})])
[('Parameter1', 'Parameter1_tensor_name')]})])
graph.graph['packed_outputs'] = None
graph.graph['user_shapes'] = None
@ -31,18 +31,18 @@ class TestsOutputCut(unittest.TestCase):
param1 = Node(graph, 'Parameter1')
self.assertTrue(param1.out_node()['type'] == 'Result')
self.assertTrue(param1.out_edge()['fw_tensor_debug_info'] == [('Parameter1', 0, 'Parameter1_tensor_name')])
self.assertTrue(param1.out_edge()['fw_tensor_debug_info'] == [('Parameter1', 'Parameter1_tensor_name')])
self.assertTrue(graph.get_op_nodes(name='FakeOutput1') == [])
def test_case2(self):
graph = build_graph(nodes, [('Parameter1', 'Op1'),
('Op1', 'FakeOutput1',
{'in': 1, 'out': 1, 'fw_tensor_debug_info':
[('Op1', 0, 'Op1_tensor_name')]}),
[('Op1', 'Op1_tensor_name')]}),
('Parameter1', 'Op2'),
('Op2', 'FakeOutput2',
{'in': 2, 'out': 3,
'fw_tensor_debug_info': [('Op2', 0, 'Op2_tensor_name')]})])
'fw_tensor_debug_info': [('Op2', 'Op2_tensor_name')]})])
graph.graph['packed_outputs'] = None
graph.graph['user_shapes'] = None
@ -53,8 +53,8 @@ class TestsOutputCut(unittest.TestCase):
op2 = Node(graph, 'Op2')
self.assertTrue(op1.out_node(1)['type'] == 'Result')
self.assertTrue(op2.out_node(3)['type'] == 'Result')
self.assertTrue(op1.out_edge(1)['fw_tensor_debug_info'] == [('Op1', 0, 'Op1_tensor_name')])
self.assertTrue(op2.out_edge(3)['fw_tensor_debug_info'] == [('Op2', 0, 'Op2_tensor_name')])
self.assertTrue(op1.out_edge(1)['fw_tensor_debug_info'] == [('Op1', 'Op1_tensor_name')])
self.assertTrue(op2.out_edge(3)['fw_tensor_debug_info'] == [('Op2', 'Op2_tensor_name')])
self.assertTrue(graph.get_op_nodes(name='FakeOutput1') == [])
self.assertTrue(graph.get_op_nodes(name='FakeOutput2') == [])

View File

@ -13,9 +13,9 @@ nodes = {
**regular_op('Op2', {'type': 'Op2', 'kind': 'op', 'op': 'Op2'}),
**regular_op('NewOp', {'type': 'NewOp', 'kind': 'op', 'op': 'NewOp'}),
'input_data': {'kind': 'data', 'fw_tensor_debug_info': [('input', 0, 'input')]},
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 0, 'Op1')]},
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op2', 0, 'Op2')]},
'input_data': {'kind': 'data', 'fw_tensor_debug_info': [('input', 'input')]},
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1')]},
'Op2_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op2', 'Op2')]},
'NewOp_data': {'kind': 'data'},
}
@ -33,9 +33,9 @@ class TestsFront(unittest.TestCase):
def test_case1_merge(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
input_node = Node(graph, 'input')
new_node = Node(graph, 'NewOp')
@ -50,9 +50,9 @@ class TestsFront(unittest.TestCase):
def test_case1_source(self):
graph = build_graph(nodes, [
('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
input_node = Node(graph, 'input')
new_node = Node(graph, 'NewOp')
@ -67,7 +67,7 @@ class TestsFront(unittest.TestCase):
def test_case1_dest(self):
graph = build_graph(nodes, [
('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('input', 'NewOp', {'in': 0, 'out': 0})])
@ -84,9 +84,9 @@ class TestsFront(unittest.TestCase):
def test_case2_merge(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -101,9 +101,9 @@ class TestsFront(unittest.TestCase):
def test_case2_source(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('input', 'NewOp', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -118,7 +118,7 @@ class TestsFront(unittest.TestCase):
def test_case2_dest(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [('input', 'NewOp', {'in': 0, 'out': 0})])
op1_node = Node(graph, 'Op1')
@ -134,9 +134,9 @@ class TestsFront(unittest.TestCase):
def test_case3_merge(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('NewOp', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('NewOp', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -151,7 +151,7 @@ class TestsFront(unittest.TestCase):
def test_case3_source(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [('NewOp', 'Op1', {'in': 0, 'out': 0})])
op1_node = Node(graph, 'Op1')
@ -167,9 +167,9 @@ class TestsFront(unittest.TestCase):
def test_case3_dest(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('NewOp', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('NewOp', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -184,9 +184,9 @@ class TestsFront(unittest.TestCase):
def test_case4_merge(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
graph_ref = build_graph(nodes, [
('NewOp', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input')]})])
('NewOp', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input')]})])
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -427,7 +427,7 @@ class TestsMiddle(unittest.TestCase):
graph_ref = build_graph(nodes, [('input', 'input_data'), ('NewOp', 'NewOp_data'), ('NewOp_data', 'Op1')])
new_op_data = Node(graph_ref, 'NewOp_data')
new_op_data['fw_tensor_debug_info'] = [('input', 0, 'input')]
new_op_data['fw_tensor_debug_info'] = [('input', 'input')]
input_data = Node(graph_ref, 'input_data')
del input_data['fw_tensor_debug_info']
@ -459,7 +459,7 @@ class TestsMiddle(unittest.TestCase):
graph_ref = build_graph(nodes, [('input', 'input_data'), ('NewOp', 'NewOp_data'), ('NewOp_data', 'Op1')])
new_op_data = Node(graph_ref, 'NewOp_data')
new_op_data['fw_tensor_debug_info'] = [('input', 0, 'input')]
new_op_data['fw_tensor_debug_info'] = [('input', 'input')]
input_data = Node(graph_ref, 'input_data')
del input_data['fw_tensor_debug_info']
@ -478,7 +478,7 @@ class TestsMiddle(unittest.TestCase):
graph_ref = build_graph(nodes, [('input', 'input_data'), ('NewOp', 'NewOp_data'), ('NewOp_data', 'Op1')])
new_op_data = Node(graph_ref, 'NewOp_data')
new_op_data['fw_tensor_debug_info'] = [('input', 0, 'input')]
new_op_data['fw_tensor_debug_info'] = [('input', 'input')]
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -507,7 +507,7 @@ class TestsMiddle(unittest.TestCase):
graph_ref = build_graph(nodes, [('input', 'input_data'), ('NewOp', 'NewOp_data'), ('NewOp_data', 'Op1')])
new_op_data = Node(graph_ref, 'NewOp_data')
new_op_data['fw_tensor_debug_info'] = [('input', 0, 'input')]
new_op_data['fw_tensor_debug_info'] = [('input', 'input')]
op1_node = Node(graph, 'Op1')
new_node = Node(graph, 'NewOp')
@ -525,7 +525,7 @@ class TestsMiddle(unittest.TestCase):
('Op1', 'Op1_data'), ('input_data', 'Op2')])
input_data = Node(graph_ref, 'input_data')
input_data['fw_tensor_debug_info'] = [('input', 0, 'input'), ('Op1', 0, 'Op1')]
input_data['fw_tensor_debug_info'] = [('input', 'input'), ('Op1', 'Op1')]
op1_data = Node(graph_ref, 'Op1_data')
del op1_data['fw_tensor_debug_info']
@ -544,7 +544,7 @@ class TestsMiddle(unittest.TestCase):
('Op1', 'Op1_data'), ('input_data', 'Op2')])
input_data = Node(graph_ref, 'input_data')
input_data['fw_tensor_debug_info'] = [('input', 0, 'input')]
input_data['fw_tensor_debug_info'] = [('input', 'input')]
op1_node = Node(graph, 'Op1')
op1_node.out_port(0).get_connection().set_source(op1_node.in_port(0).get_source(), "source")
@ -560,7 +560,7 @@ class TestsMiddle(unittest.TestCase):
('Op1', 'Op1_data'), ('input_data', 'Op2')])
input_data = Node(graph_ref, 'input_data')
input_data['fw_tensor_debug_info'] = [('Op1', 0, 'Op1')]
input_data['fw_tensor_debug_info'] = [('Op1', 'Op1')]
op1_data = Node(graph_ref, 'Op1_data')
del op1_data['fw_tensor_debug_info']
@ -579,7 +579,7 @@ class TestsMiddle(unittest.TestCase):
('Op1', 'Op1_data')])
input_data = Node(graph_ref, 'input_data')
input_data['fw_tensor_debug_info'] = [('input', 0, 'input'), ('Op1', 0, 'Op1')]
input_data['fw_tensor_debug_info'] = [('input', 'input'), ('Op1', 'Op1')]
op1_node = Node(graph, 'Op1')
op1_node.in_port(0).get_connection().set_destination(op1_node.out_port(0).get_destination(), "merge")
@ -595,7 +595,7 @@ class TestsMiddle(unittest.TestCase):
('Op1', 'Op1_data')])
input_data = Node(graph_ref, 'input_data')
input_data['fw_tensor_debug_info'] = [('input', 0, 'input')]
input_data['fw_tensor_debug_info'] = [('input', 'input')]
op1_node = Node(graph, 'Op1')
op1_node.in_port(0).get_connection().set_destination(op1_node.out_port(0).get_destination(), "source")
@ -611,7 +611,7 @@ class TestsMiddle(unittest.TestCase):
('Op1', 'Op1_data')])
input_data = Node(graph_ref, 'input_data')
input_data['fw_tensor_debug_info'] = [('Op1', 0, 'Op1')]
input_data['fw_tensor_debug_info'] = [('Op1', 'Op1')]
op1_node = Node(graph, 'Op1')
op1_node.in_port(0).get_connection().set_destination(op1_node.out_port(0).get_destination(), "dest")

View File

@ -12,18 +12,18 @@ nodes = {
**regular_op('Op2', {'type': 'Op2', 'kind': 'op', 'op': 'Op2'}),
**regular_op('Op3', {'type': 'Op3', 'kind': 'op', 'op': 'Op3'}),
'input_data': {'kind': 'data', 'fw_tensor_debug_info': [('input', 0, 'input'), ('Op1', 0, 'Op1,Op2')]},
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 0, 'Op1,Op2')]},
'input_data': {'kind': 'data', 'fw_tensor_debug_info': [('input', 'input'), ('Op1', 'Op1,Op2')]},
'Op1_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op1', 'Op1,Op2')]},
'Op2_data': {'kind': 'data'},
'Op3_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op3', 0, 'Op3')]},
'Op3_data': {'kind': 'data', 'fw_tensor_debug_info': [('Op3', 'Op3')]},
}
class TestsGetTensorNames(unittest.TestCase):
def test_front(self):
graph = build_graph(nodes,
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input'),
('Op1', 0, 'Op1,Op2')]})])
[('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
('Op1', 'Op1,Op2')]})])
graph.stage = 'front'
input_node = Node(graph, 'input')
self.assertTrue(input_node.out_port(0).get_tensor_names() == ['input', 'Op1\\,Op2'])
@ -72,10 +72,9 @@ class TestsGetTensorNames(unittest.TestCase):
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3', 'input', 'Op1\\,Op2'])
def test_reconnect_front_case1(self):
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 0, 'input'),
('Op1', 0,
'Op1,Op2')]}),
('Op3', 'Op2', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op3', 0, 'Op3')]})])
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
('Op1', 'Op1,Op2')]}),
('Op3', 'Op2', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op3', 'Op3')]})])
graph.stage = 'front'
input_node = Node(graph, 'input')

View File

@ -126,7 +126,7 @@ class TestFunction(unittest.TestCase):
node_2 = Node(graph, 'add_data')
node_3 = Node(graph, 'add_const_data')
assert node_1['fw_tensor_debug_info'] == [('abc', 0, 'abc'), ('def', 0, 'def')], 'Restored debug info is wrong!'
assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 0, 'ghi,jkl')], 'Restored debug info is wrong!'
assert node_3['fw_tensor_debug_info'] == [('mno', 0, 'mno'), ('pqr,stu', 0, 'pqr,stu')],\
assert node_1['fw_tensor_debug_info'] == [('abc', 'abc'), ('def', 'def')], 'Restored debug info is wrong!'
assert node_2['fw_tensor_debug_info'] == [('ghi,jkl', 'ghi,jkl')], 'Restored debug info is wrong!'
assert node_3['fw_tensor_debug_info'] == [('mno', 'mno'), ('pqr,stu', 'pqr,stu')],\
'Restored debug info is wrong!'