Parameter/Result tensor names fix (#9754)

* Added tensor names checking transformation. Added saving of framework names to the tensor names list for input/output cut.

* Small fixes.

* Fixed tensor propagation in PowerToEltwises.

* Corrected tensor checking.

* Corrected tensor checking.

* Fixed MemoryOffsetAdjustment().

* Fixed tensor propagation for Yolo, ONNXMaskRCNNTransformation().

* Small fix.

* DetectionOutput tensor name set.

* Tensor name set for Reshape node in OD API.

* Temporarily added set of tensor names in ConvertGroupedStridedSlice.

* Small corrections, added tests.

* Added checks.

* Added default names setting.

* Moved default names setting to single place.

* Added port normalization before setting tensor names.

* Fixed ResultRename logic.

* Removed tensor setting of unset ports.

* Corrected input cut tensor naming.

* Corrected InputCut, renamed set_tensor_names()->add_tensor_names().

* Fixed tensor setting for InputCut.

* Code corrections.
Anastasia Popova 2022-01-26 17:23:46 +03:00 committed by GitHub
parent 8f94d6dd3f
commit 84f2e9fc24
11 changed files with 307 additions and 28 deletions


@@ -31,8 +31,21 @@ class FakeOutputResolver(BackReplacementPattern):
add = create_op_with_const_inputs(graph, Add, {1: int64_array(0)}, {'can_be_fused': False})
rename_nodes([(fake_output, name + '/TBD'), (add, name)])
# Get tensor names incoming to FakeOutput
tensor_names = fake_output.in_port(0).get_connection().source.get_tensor_names()
# Remove tensor info from data node
in_data_node = fake_output.in_node()
if 'fw_tensor_debug_info' in in_data_node:
del in_data_node['fw_tensor_debug_info']
fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
fake_output.out_port(0).get_connection().set_source(add.out_port(0))
# Move tensor names to Add op, which replaces FakeOutput
if len(tensor_names) > 0:
add.out_port(0).add_tensor_names([add.name], [tensor_names])
else:
result_in_port = fake_output.out_port(0).get_destination()
result_in_port.disconnect()
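
The hunk above is the main consumer of the renamed API: capture the framework tensor names from the port feeding FakeOutput, rewire the graph, then re-attach the names to the Add that takes FakeOutput's place. A condensed sketch of that hand-off (fake_output and add are the nodes from the pattern above):

# 1. Capture the names while the old connection still exists.
tensor_names = fake_output.in_port(0).get_connection().source.get_tensor_names()
# 2. Rewire the graph around FakeOutput.
fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
fake_output.out_port(0).get_connection().set_source(add.out_port(0))
# 3. Re-attach all captured names to the replacement op under its own op name.
if tensor_names:
    add.out_port(0).add_tensor_names([add.name], [tensor_names])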


@@ -13,6 +13,11 @@ class ResultRename(BackReplacementPattern):
enabled = False
def find_and_replace_pattern(self, graph: Graph):
names_set = set()
for node in graph.get_op_nodes():
if node.has_valid('name'):
names_set.add(node['name'])
for node in graph.get_op_nodes(type='Result'):
if node.in_ports():
prev_node_out_port = node.in_port(0).get_connection().get_source()
@@ -20,11 +25,24 @@ class ResultRename(BackReplacementPattern):
# The graph may contain Result nodes whose names equal input tensor names;
# renaming is not needed in that case. An example of such a situation is the
# IR reader check, where the graph is read with correct Result names.
if tensor_names and node.soft_get('name') == tensor_names[0]:
if not tensor_names:
result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
'/sink_port_' + str(prev_node_out_port.idx)
node['name'] = result_name
continue
if tensor_names and not graph.get_op_nodes(name=tensor_names[0]):
result_name = tensor_names[0]
else:
# If Result name is equal to some tensor name from list, then renaming is not needed
if node.soft_get('name') in tensor_names:
continue
# Try to find a tensor name that does not intersect with graph node names
result_name = None
for tensor_name in tensor_names:
if tensor_name not in names_set:
result_name = tensor_name
break
# If no appropriate tensor name was found, the Result gets the default name
if result_name is None:
result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
'/sink_port_' + str(prev_node_out_port.idx)
node['name'] = result_name
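
Read as a whole, the new renaming policy reduces to a small pure function (an illustrative sketch only; pick_result_name is a hypothetical helper, not part of MO):

def pick_result_name(node_name, tensor_names, graph_node_names, default_name):
    # No tensor names on the producing port: use '<producer>/sink_port_<idx>'.
    if not tensor_names:
        return default_name
    # First tensor name does not clash with any op name: take it.
    if tensor_names[0] not in graph_node_names:
        return tensor_names[0]
    # The Result already carries one of its tensor names: keep it.
    if node_name in tensor_names:
        return node_name
    # Otherwise take the first non-clashing tensor name, if any.
    for name in tensor_names:
        if name not in graph_node_names:
            return name
    return default_name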


@@ -21,21 +21,21 @@ class PowerToEltwises(FrontReplacementOp):
const = Const(graph, {'value': mo_array(op.scale)}).create_node()
mul = Mul(graph, {'name': op.name + '/mul_'}).create_node()
const.out_port(0).connect(mul.in_port(1))
out_port.connect(mul.in_port(0))
mul.in_port(0).get_connection().set_source(out_port)
out_port = mul.out_port(0)
if op.soft_get('shift', 0) != 0:
const = Const(graph, {'value': mo_array(op.shift)}).create_node()
add = Add(graph, {'name': op.name + '/add_'}).create_node()
const.out_port(0).connect(add.in_port(1))
out_port.connect(add.in_port(0))
add.in_port(0).get_connection().set_source(out_port)
out_port = add.out_port(0)
if op.soft_get('power', 1) != 1:
const = Const(graph, {'value': mo_array(op.power)}).create_node()
pow = Pow(graph, {'name': op.name + '/pow_'}).create_node()
const.out_port(0).connect(pow.in_port(1))
out_port.connect(pow.in_port(0))
pow.in_port(0).get_connection().set_source(out_port)
out_port = pow.out_port(0)
op.out_port(0).get_connection().set_source(out_port)
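
The change in this hunk swaps connect() for set_source() on the existing connection: connect() creates a fresh edge with no fw_tensor_debug_info, while set_source() rewires the connection the names already live on, so they travel down the decomposed Mul/Add/Pow chain. A minimal sketch of the difference (a and b are hypothetical front-stage nodes):

# connect(): adds a new edge; tensor names stay behind on the original edge.
a.out_port(0).connect(b.in_port(0))

# set_source(): moves the existing connection, carrying fw_tensor_debug_info
# along (subject to attributes_save_mode).
b.in_port(0).get_connection().set_source(a.out_port(0))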


@@ -780,7 +780,7 @@ def add_output_ops(graph: Graph, user_defined_outputs: dict, inputs: dict = None
refer_to_faq_msg(29), value['in'], node)
for u, v, attrs in in_edges:
if 'in' in attrs and attrs['in'] == value['in']:
sinks.append(add_opoutput(graph, u, attrs['out']))
sinks.append(add_opoutput(graph, u, attrs['out'], user_defined_name=node))
elif 'out' in value:
out_edges = list(graph.out_edges(node, data=True))
if len(out_edges) - 1 < value['out']:
@@ -788,9 +788,9 @@ def add_output_ops(graph: Graph, user_defined_outputs: dict, inputs: dict = None
refer_to_faq_msg(29), value['out'], node)
for u, v, attrs in out_edges:
if 'out' in attrs and attrs['out'] == value['out']:
sinks.append(add_opoutput(graph, node, attrs['out']))
sinks.append(add_opoutput(graph, node, attrs['out'], user_defined_name=node))
else:
sinks.append(add_opoutput(graph, node, 0))
sinks.append(add_opoutput(graph, node, 0, user_defined_name=node))
return sinks
@@ -870,13 +870,16 @@ def add_input_op_input_port_with_data(graph: Graph, node_id: str, input_op, edge
out_port.data.set_shape(input_node.soft_get('shape', None))
input_data_node = input_node.out_node(0)
if 'fw_tensor_debug_info' in edge_attrs:
input_data_node['fw_tensor_debug_info'] = edge_attrs['fw_tensor_debug_info']
log.debug('Input: {} for node {}'.format(input_node.id, node_id))
log.debug("Add edge from {} to {}".format(input_node.id, input_data_node.id))
log.debug("Add edge from {} to {}".format(input_data_node.id, node_id))
return input_node.id
def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op, port: int):
def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op, port: int, fw_info: list):
input_node = input_op.create_node()
# In this case there can be more than one out edge from one port, so we iterate over all output edges
for _, out_node, attrs in graph.out_edges(node_id, data=True):
@@ -884,17 +887,20 @@ def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op,
# new out port = 0
attrs = attrs.copy()
attrs['out'] = 0
attrs['fw_tensor_debug_info'] = fw_info
attrs['data_attrs'] = ['fw_tensor_debug_info']
graph.add_edge(input_node.id, out_node, **attrs)
log.debug('Input: {} for node {} output port {}'.format(input_node.id, node_id, port))
log.debug("Add edge from {} to {}".format(input_node.id, out_node))
return input_node.id
def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int):
def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int, fw_info: list):
# we assume that an op node is always followed by a data node
assert graph.stage == 'middle', 'add_input_op_output_port_with_data() function can be used only for graph after ' \
'shape inference!'
data_node = Node(graph, node_id).out_node(port)
data_node['fw_tensor_debug_info'] = fw_info
assert data_node.has_valid('kind') and data_node.kind == 'data'
input_node = input_op.create_node()
Node(graph, node_id).out_port(port).get_connection().set_source(input_node.out_port(0))
@@ -924,21 +930,34 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id,
name=get_new_placeholder_name(node_id, is_out_port, port)))
fw_name = Node(graph, node_id).soft_get('name')
tensor_name = Node(graph, node_id).soft_get('name') + ":" + str(port)
fw_info = [(Node(graph, node_id).soft_get('name'), tensor_name)]
if is_out_port and port == 0:
tensor_name_no_port = Node(graph, node_id).soft_get('name')
# TODO: This can be optimized. Tensor names can be stored as a set initialized after model loading.
graph_tensor_names = graph.get_tensor_names_set()
if tensor_name_no_port in graph_tensor_names:
log.warning('Could not add user defined input name {} to the tensor names list, as the '
'graph already contains a tensor with the same name.'.format(tensor_name_no_port))
else:
# Add alias with operation name, as this format is used in some config files
fw_info.append((Node(graph, node_id).soft_get('name'), tensor_name_no_port))
edge_attrs = {'in': port, 'out': 0, 'in_attrs': ['in'], 'out_attrs': ['out'],
'fw_tensor_debug_info': [(fw_name, fw_name)],
'fw_tensor_debug_info': fw_info,
'data_attrs': ['fw_tensor_debug_info']}
if not data:
if is_out_port:
new_input_id = add_input_op_output_port_without_data(graph=graph, node_id=node_id, input_op=input_op,
port=port)
port=port, fw_info=edge_attrs['fw_tensor_debug_info'])
else:
new_input_id = add_input_op_input_port_without_data(graph=graph, node_id=node_id, input_op=input_op,
edge_attrs=edge_attrs)
else:
if is_out_port:
new_input_id = add_input_op_output_port_with_data(graph=graph, node_id=node_id, input_op=input_op,
port=port)
port=port, fw_info=edge_attrs['fw_tensor_debug_info'])
else:
new_input_id = add_input_op_input_port_with_data(graph=graph, node_id=node_id, input_op=input_op,
edge_attrs=edge_attrs)
@@ -1064,6 +1083,16 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
smart_node['data_type'] = data_type
inputs.append(node_id)
port_and_shape_info['added'] = True
if smart_node.out_edges():
# The user-specified input is already a Parameter, so no input cut is needed,
# but the op name still has to be added to the tensor names
fw_info = []
op_name = smart_node.soft_get('name')
if 'fw_tensor_debug_info' in smart_node.out_edge(0):
fw_info += smart_node.out_edge(0)['fw_tensor_debug_info']
smart_node.out_edge(0)['fw_tensor_debug_info'] = fw_info + [(op_name, op_name)]
continue
if before_infer:

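Throughout this file the framework tensor names travel as fw_tensor_debug_info, a list of (op_name, tensor_name) pairs stored on an edge. A sketch of the structure the input-cut code above builds for an assumed node named 'data' (out port 0):

fw_info = [('data', 'data:0'),  # '<name>:<port>' tensor name
           ('data', 'data')]    # bare op-name alias, added for out port 0 only
edge_attrs = {'in': 0, 'out': 0, 'in_attrs': ['in'], 'out_attrs': ['out'],
              'fw_tensor_debug_info': fw_info,
              'data_attrs': ['fw_tensor_debug_info']}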

@@ -54,8 +54,9 @@ def align_frame_time(graph: Graph, node: Node, frame_time_max):
# add element_size for MemoryOffset after Parameter for infer
if in_node.op == 'Parameter':
memory_align['element_size'] = in_node.shape
memory_align.in_port(0).get_connection().set_source(in_node_out_port)
in_port.get_connection().set_source(memory_align.out_port(0))
memory_align.in_port(0).connect(in_node_out_port)
memory_align['frame_time'] = memory_align.t
# remove MemoryOffset with maximum delay
elif in_node.frame_time == frame_time_max and in_node.op == 'MemoryOffset':


@@ -62,7 +62,7 @@ class Connection:
def get_destinations(self):
return self.destinations
def set_source(self, port, attributes_save_mode: str = "merge"):
def set_source(self, port, attributes_save_mode=None):
# In this method we are changing source for a connection with given port.
# See detailed example below.
#
@@ -99,6 +99,16 @@ class Connection:
if self.control_flow is True:
raise Error("Cannot operate with connection with control_flow=True")
if attributes_save_mode is None:
attributes_save_mode = "merge"
if self.source is not None:
scr_node = self.source.node
# Force "source" mode for a "Parameter" source node, which preserves the
# tensor names on the connection's source node.
if scr_node.soft_get("type") == "Parameter":
attributes_save_mode = "source"
if self.graph.stage == 'front':
scr_node = port.node
@@ -161,7 +171,7 @@ class Connection:
else:
self.graph.add_edge(port_out_data.id, dst_port.node.id, **{'in': dst_port.idx})
def set_destination(self, port, attributes_save_mode: str = "merge"):
def set_destination(self, port, attributes_save_mode=None):
# In this method we are changing destination for a connection with given port with type 'in'.
# This method requires exactly one destination or empty destinations list.
# See detailed example below.
@@ -212,6 +222,16 @@ class Connection:
if self.control_flow is True:
raise Error("Cannot operate with connection with control_flow=True")
if attributes_save_mode is None:
attributes_save_mode = "merge"
if self.source is not None:
scr_node = self.source.node
# Force "source" mode for a "Parameter" source node, which preserves the
# tensor names on the connection's source node.
if scr_node.soft_get("type") == "Parameter":
attributes_save_mode = "source"
if self.graph.stage == 'front':
if self.source is not None:
node = self.source.node
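
The net effect, mirrored by the reconnect tests further below: when no attributes_save_mode is passed and the current source is a Parameter, the tensor names stay on the Parameter; an explicitly passed mode behaves as before. A behavior sketch (input_node typed as Parameter, op3_node an ordinary op, as in the tests):

conn = input_node.out_port(0).get_connection()

# Default mode is forced to "source" for a Parameter source, so names such as
# ['Op1\\,Op2', 'input'] remain on input_node's out port.
conn.set_source(op3_node.out_port(0))

# An explicit "merge" still moves/merges the names onto op3_node's out port.
conn.set_source(op3_node.out_port(0), "merge")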


@@ -1003,6 +1003,24 @@ class Graph(nx.MultiDiGraph):
# Add Const op for constant data nodes
add_constant_operations(self)
def get_tensor_names_set(self, use_ports=False):
"""
Get set of tensor names of the graph.
"""
tensor_names_set = set()
for node in self.get_op_nodes():
if self.stage is None:
for out_edge_idx in node.out_edges():
out_edge = node.out_edge(out_edge_idx)
if "fw_tensor_debug_info" in out_edge:
for _, tensor_name in out_edge["fw_tensor_debug_info"]:
tensor_names_set.add(tensor_name)
else:
for _, port in node.out_ports().items():
tensor_names = port.get_tensor_names()
tensor_names_set = tensor_names_set.union(set(tensor_names))
return tensor_names_set
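
Both the input-cut and output-cut paths use this set the same way, as a collision guard before registering a user-defined name. A usage sketch (requested_name is an assumed user-provided name):

requested_name = 'prob'  # hypothetical user-defined name
if requested_name in graph.get_tensor_names_set():
    log.warning('Tensor name {} already exists in the graph; '
                'not adding it.'.format(requested_name))
else:
    # safe to append (requested_name, requested_name) to fw_tensor_debug_info
    pass
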
def topological_sort(self, reverse: bool = False):
sorted_node_ids = nx.topological_sort(self)
@@ -1051,7 +1069,8 @@ def dict_includes(big: dict, sub_dict: dict, skip_attr_names=[]):
)
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep_output_port: bool = False):
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep_output_port: bool = False,
user_defined_name=None):
"""
Creates and connects Result node to node_name port. Cuts existing port if requested.
:param graph: graph to operate with
@@ -1059,6 +1078,7 @@ def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep
:param port: output port of node to connect Result to
:param cut: determines way of operating with edge specified by node_name and port
:param keep_output_port: special attribute determines if this operation is saved in IR or not
:param user_defined_name: user-defined operation name to add to the tensor names list
"""
# we import it here because Op imports add_attrs_props and update_ie_fields from this file
from openvino.tools.mo.ops.result import Result
@@ -1071,6 +1091,26 @@ def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep
'keep_output_port': keep_output_port})
opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
if user_defined_name is not None and (graph.stage == 'front' or graph.stage is None):
# The following code adds user_defined_name to the tensor names list.
# Not applicable for the middle stage.
prev_op_tensor_names = set()
in_edge_attrs = opoutput_node.in_edge()
if 'fw_tensor_debug_info' in in_edge_attrs:
for _, tensor_name in opoutput_node.in_edge()['fw_tensor_debug_info']:
prev_op_tensor_names.add(tensor_name)
if user_defined_name not in prev_op_tensor_names:
# TODO: This can be optimized. Tensor names can be stored as a set initialized after model loading.
graph_tensor_names = graph.get_tensor_names_set()
if user_defined_name in graph_tensor_names:
log.warning('Could not add user defined output name {} to the tensor names list of the {} node, '
'as the graph already contains a tensor with the same name.'.format(user_defined_name,
opoutput_node.soft_get('name')))
else:
if 'fw_tensor_debug_info' not in in_edge_attrs:
in_edge_attrs['fw_tensor_debug_info'] = []
in_edge_attrs['fw_tensor_debug_info'].append([user_defined_name, user_defined_name])
log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))
log.debug(str(graph.node[opoutput_node.id]))
log.debug("Add edge from {} to {}".format(node_name, opoutput_node.id))
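
With the new parameter, the output-cut code above can register the user-facing name in one call. A usage sketch ('conv1' is a hypothetical node name):

# Cut the graph after out port 0 of 'conv1' and keep 'conv1' available as a
# tensor name on the new Result, unless the graph already owns that name.
add_opoutput(graph, 'conv1', 0, user_defined_name='conv1')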


@@ -274,6 +274,29 @@ class Port:
tensor_names_list.append(tensor_name.replace(',', '\\,'))
return sorted(tensor_names_list)
def add_tensor_names(self, op_names: list, tensor_names: list, port_renumber: bool = False):
"""
Adds tensor names to the port's tensor names list.
:param op_names: list of op names.
:param tensor_names: list of lists of tensor names.
:param port_renumber: defines whether data node index should be calculated considering port renumbering.
"""
assert len(tensor_names) == len(op_names), \
"Number of tensor name lists should be the same as the number of operation names."
if len(tensor_names) == 0:
return
new_debug_items = []
for names, op_name in zip(tensor_names, op_names):
assert isinstance(names, list), "Tensor names elements should be lists of strings."
new_debug_items.append((op_name, ','.join(names)))
tensor_debug_info = self.get_tensor_debug_info(port_renumber)
tensor_debug_info += new_debug_items
self.set_tensor_debug_info(tensor_debug_info, port_renumber)
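
The resulting storage format shows up in the updated tests below; as a standalone sketch with plain data (append_names is a hypothetical helper, not part of MO):

def append_names(debug_info, op_names, tensor_names):
    # Mirrors add_tensor_names(): one (op_name, 'n1,n2,...') tuple per op.
    for names, op_name in zip(tensor_names, op_names):
        debug_info.append((op_name, ','.join(names)))
    return debug_info

info = [('input', 'input'), ('Op1', 'Op1,Op2')]
append_names(info, ['A', 'B'], [['A:0'], ['B:0', 'B:1', 'B:2']])
# info == [('input', 'input'), ('Op1', 'Op1,Op2'), ('A', 'A:0'), ('B', 'B:0,B:1,B:2')]
# get_tensor_names() later escapes embedded commas: 'B:0\,B:1\,B:2'
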
def get_tensor_debug_info(self, port_renumber: bool = False):
"""
Gets tensor debug info attribute.
@@ -306,6 +329,31 @@ class Port:
fw_debug_info += get_tensor_debug_info_from_attrs(out_node.attrs())
return fw_debug_info
def set_tensor_debug_info(self, tensor_info: list, port_renumber: bool = False):
"""
Sets the tensor debug info attribute.
:param tensor_info: new tensor debug info value.
:param port_renumber: defines whether data node index should be calculated considering port renumbering.
"""
assert self.type != 'in', "Can't set tensor debug info for input port at {} node".format(self.node.name)
if self.node.graph.stage == 'front':
if self.idx in self.node.out_edges():
out_edge = self.node.out_edge(self.idx)
out_edge['fw_tensor_debug_info'] = tensor_info
else:
# before port renumbering we use sequential numbering
node_idx = self.idx
if port_renumber:
if self.node.type != 'Const':
# after port renumbering port indices start from zero,
# but data node indices remain the same
node_idx = self.idx + len(self.node.in_nodes())
if node_idx in self.node.out_nodes():
out_node = self.node.out_node(node_idx)
out_node['fw_tensor_debug_info'] = tensor_info
def disconnect(self):
if self.type == 'out':
@@ -416,3 +464,12 @@ class Port:
raise Error('Trying to override data type for output port {} of operation {}: from {} to {}'.format(
self.idx, node.name, node._out_port_data_type[self.idx], data_type))
node._out_port_data_type[self.idx] = data_type
def get_default_tensor_name(self):
"""
Gets the default tensor name of the port: "<node name>:<port index>".
:return: tensor name, or None for input ports
"""
if self.type == 'in':
return None
return self.node.soft_get('name', self.node.id) + ":" + str(self.idx)
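
So a Parameter named 'data' gets 'data:0' by default, the same '<name>:<port>' format the input cut uses. A one-line standalone sketch:

def default_tensor_name(node_name: str, port_idx: int) -> str:
    # Same format as Port.get_default_tensor_name() above.
    return node_name + ":" + str(port_idx)

assert default_tensor_name("data", 0) == "data:0"  # 'data' is a hypothetical name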


@@ -142,7 +142,10 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
const_port.disconnect()
# as the Mul node is added before the convolution, the output tensor of the
# Convolution node corresponds to the original Mul node
node.out_port(0).get_connection().set_source(producer_port, "dest")
if producer_port.node.soft_get('type') == 'Parameter':
node.out_port(0).get_connection().set_source(producer_port, "source")
else:
node.out_port(0).get_connection().set_source(producer_port, "dest")
return is_fused


@@ -10,14 +10,17 @@ import networkx as nx
from openvino.tools.mo.back.RemoveUselessConvert import RemoveUselessConvert
from openvino.tools.mo.back.ResultRename import ResultRename
from openvino.tools.mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, \
serialize_mean_image
from openvino.tools.mo.back.op_versioning import OpVersioning
from openvino.tools.mo.ops.Cast import Cast
from openvino.tools.mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
from openvino.tools.mo.graph.graph import Node, Graph
from openvino.tools.mo.middle.passes import tensor_names, convert_data_type
from openvino.tools.mo.middle.passes.convert_data_type import data_type_str_to_np
from openvino.tools.mo.middle.passes.eliminate import shape_inference
from openvino.tools.mo.middle.passes.infer import type_infer
from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
from openvino.tools.mo.ops.Cast import Cast
from openvino.tools.mo.ops.op import Op
from openvino.tools.mo.utils.error import Error
@@ -171,6 +174,25 @@ def convert_inputs_of_specific_ops(graph: Graph):
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
def set_default_tensor_names_for_parameters_results(graph: Graph):
for node in graph.get_op_nodes():
if node.soft_get('type') == 'Result' and node.is_in_port_connected(0):
port = node.in_port(0).get_connection().get_source()
elif node.soft_get('type') == 'Parameter' and node.is_out_port_connected(0):
port = node.out_port(0)
else:
continue
if node.has_and_set('keep_output_port'):
continue
tensors = port.get_tensor_names()
if tensors is not None and isinstance(tensors, list) and len(tensors) > 0:
continue
new_tensor_name = port.get_default_tensor_name()
op_name = port.node.soft_get('name')
port.add_tensor_names([op_name], [[new_tensor_name]])
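
Together with ResultRename, this pass guarantees that every serialized Parameter and Result port carries at least one tensor name. The core of the pass as a free-function sketch (ensure_default_names is a hypothetical helper):

def ensure_default_names(ports_without_names):
    # Give each still-unnamed port its '<op name>:<port idx>' default.
    for port in ports_without_names:
        op_name = port.node.soft_get('name')
        port.add_tensor_names([op_name], [[port.get_default_tensor_name()]])
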
def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_name: str,
mean_data: [list, None] = None, input_names: list = None, meta_info: dict = None,
use_temporary_path=False, convert_types=False):
@@ -199,6 +221,7 @@ def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_
for_graph_and_each_sub_graph_recursively(graph, RemoveUselessConvert().find_and_replace_pattern)
ResultRename().find_and_replace_pattern(graph)
set_default_tensor_names_for_parameters_results(graph)
for sub_graph in [graph] + collect_sub_graphs(graph):
op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))


@@ -32,6 +32,12 @@ class TestsGetTensorNames(unittest.TestCase):
op1_node.add_output_port(0)
self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])
input_node.out_port(0).add_tensor_names(["A", "B", "C"], [["A:0"], ["B:0", "B:1", "B:2"], ["C:0"]])
self.assertTrue(input_node.out_port(0).get_tensor_debug_info() ==
[('input', 'input'), ('Op1', 'Op1,Op2'), ("A", "A:0"), ("B", "B:0,B:1,B:2"), ("C", "C:0")])
self.assertTrue(input_node.out_port(0).get_tensor_names() ==
['A:0', 'B:0\\,B:1\\,B:2', 'C:0', 'Op1\\,Op2', 'input'])
def test_middle(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
('input_data', 'Op2')])
@@ -47,6 +53,12 @@ class TestsGetTensorNames(unittest.TestCase):
op2_node.add_output_port(0)
self.assertTrue(op2_node.out_port(0).get_tensor_names() == [])
input_node.out_port(0).add_tensor_names(["A", "B", "C"], [["A:0"], ["B:0", "B:1", "B:2"], ["C:0"]])
self.assertTrue(input_node.out_port(0).get_tensor_debug_info() ==
[('input', 'input'), ('Op1', 'Op1,Op2'), ("A", "A:0"), ("B", "B:0,B:1,B:2"), ("C", "C:0")])
self.assertTrue(input_node.out_port(0).get_tensor_names() ==
['A:0', 'B:0\\,B:1\\,B:2', 'C:0', 'Op1\\,Op2', 'input'])
def test_port_renumber(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
('Op1', 'Op1_data', {'out': 1}), ('Op1_data', 'Op2')])
@@ -58,6 +70,13 @@ class TestsGetTensorNames(unittest.TestCase):
self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == ['Op1\\,Op2'])
input_node.out_port(0).add_tensor_names(["A", "B", "C"], [["A:0"], ["B:0", "B:1", "B:2"], ["C:0"]],
port_renumber=True)
self.assertTrue(input_node.out_port(0).get_tensor_debug_info(port_renumber=True) ==
[('input', 'input'), ('Op1', 'Op1,Op2'), ("A", "A:0"), ("B", "B:0,B:1,B:2"), ("C", "C:0")])
self.assertTrue(input_node.out_port(0).get_tensor_names(port_renumber=True) ==
['A:0', 'B:0\\,B:1\\,B:2', 'C:0', 'Op1\\,Op2', 'input'])
def test_reconnect_middle_case1(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
input_node = Node(graph, 'input')
@@ -65,11 +84,24 @@ class TestsGetTensorNames(unittest.TestCase):
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
self.assertTrue(input_node_out_port.get_tensor_names() is None)
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
def test_reconnect_middle_case1_parameter(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
input_node = Node(graph, 'input')
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
self.assertTrue(input_node_out_port.get_tensor_names() is None)
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
def test_reconnect_front_case1(self):
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
@@ -81,11 +113,27 @@ class TestsGetTensorNames(unittest.TestCase):
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
self.assertTrue(input_node_out_port.get_tensor_names() == [])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
def test_reconnect_front_case1_parameter(self):
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
('Op1', 'Op1,Op2')]}),
('Op3', 'Op2', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op3', 'Op3')]})])
graph.stage = 'front'
input_node = Node(graph, 'input')
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
self.assertTrue(input_node_out_port.get_tensor_names() == [])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
def test_reconnect_middle_case1(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
@@ -94,9 +142,36 @@ class TestsGetTensorNames(unittest.TestCase):
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
self.assertTrue(input_node_out_port.get_tensor_names() == [])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
def test_reconnect_middle_case1_parameter(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
input_node = Node(graph, 'input')
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
def test_reconnect_middle_case2(self):
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1', {'out': 0}),
('input_data', 'Op1', {'out': 1}), ('Op3', 'Op3_data')])
input_node = Node(graph, 'input')
input_node_out_port = input_node.out_port(0)
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
self.assertTrue(input_node_out_port.get_tensor_names() == [])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
@@ -111,8 +186,8 @@ class TestsGetTensorNames(unittest.TestCase):
op3_node = Node(graph, 'Op3')
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
self.assertTrue(input_node_out_port.get_tensor_names() == [])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
class TestPortMethods(unittest.TestCase):