Parameter/Result tensor names fix (#9754)
* Added tensor names checking transformation. Added framework name saving to tensor names list for input/output cut. * Small fixes. * Fixed tensor propagation in PowerToEltwises. * Corrected tensor checking. * Corrected tensor checking. * Fixed MemoryOffsetAdjustment(). * Fixed tensor propagation for Yolo, ONNXMaskRCNNTransformation(). * Small fix. * DetectionOutput tensor name set. * Tensor name set for Reshape node in OD API. * Temporarily added set of tensor names in ConvertGroupedStridedSlice. * Small corrections, added tests. * Added checks. * Added deafault names setting. * Moved default names setting to single place. * Added port normilize befor setting tensor names. * Fixed ResultRename logic. * Removed tensor setting of unset ports. * Corrected input cut tensor naming. * Corrected InputCut, renamed set_tensor_names()->add_tensor_names(). * Fixed tensor setting for InputCut. * Code corrections.
This commit is contained in:
parent
8f94d6dd3f
commit
84f2e9fc24
@ -31,8 +31,21 @@ class FakeOutputResolver(BackReplacementPattern):
|
|||||||
add = create_op_with_const_inputs(graph, Add, {1: int64_array(0)}, {'can_be_fused': False})
|
add = create_op_with_const_inputs(graph, Add, {1: int64_array(0)}, {'can_be_fused': False})
|
||||||
rename_nodes([(fake_output, name + '/TBD'), (add, name)])
|
rename_nodes([(fake_output, name + '/TBD'), (add, name)])
|
||||||
|
|
||||||
|
# Get tensor names incoming to FakeOutput
|
||||||
|
tensor_names = fake_output.in_port(0).get_connection().source.get_tensor_names()
|
||||||
|
|
||||||
|
# Remove tensor info from data node
|
||||||
|
in_data_node = fake_output.in_node()
|
||||||
|
if 'fw_tensor_debug_info' in in_data_node:
|
||||||
|
del in_data_node['fw_tensor_debug_info']
|
||||||
|
|
||||||
fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
|
fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
|
||||||
fake_output.out_port(0).get_connection().set_source(add.out_port(0))
|
fake_output.out_port(0).get_connection().set_source(add.out_port(0))
|
||||||
|
|
||||||
|
# Move tensor names to Add op, which replaces FakeOutput
|
||||||
|
if len(tensor_names) > 0:
|
||||||
|
add.out_port(0).add_tensor_names([add.name], [tensor_names])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
result_in_port = fake_output.out_port(0).get_destination()
|
result_in_port = fake_output.out_port(0).get_destination()
|
||||||
result_in_port.disconnect()
|
result_in_port.disconnect()
|
||||||
|
@ -13,6 +13,11 @@ class ResultRename(BackReplacementPattern):
|
|||||||
enabled = False
|
enabled = False
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
names_set = set()
|
||||||
|
for node in graph.get_op_nodes():
|
||||||
|
if node.has_valid('name'):
|
||||||
|
names_set.add(node['name'])
|
||||||
|
|
||||||
for node in graph.get_op_nodes(type='Result'):
|
for node in graph.get_op_nodes(type='Result'):
|
||||||
if node.in_ports():
|
if node.in_ports():
|
||||||
prev_node_out_port = node.in_port(0).get_connection().get_source()
|
prev_node_out_port = node.in_port(0).get_connection().get_source()
|
||||||
@ -20,11 +25,24 @@ class ResultRename(BackReplacementPattern):
|
|||||||
# Graph may contain Result nodes with names equal to input tensors and
|
# Graph may contain Result nodes with names equal to input tensors and
|
||||||
# renaming in this case is not needed. The example of such situation is
|
# renaming in this case is not needed. The example of such situation is
|
||||||
# IR reader check when graph is read with correct Result names.
|
# IR reader check when graph is read with correct Result names.
|
||||||
if tensor_names and node.soft_get('name') == tensor_names[0]:
|
if not tensor_names:
|
||||||
continue
|
result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
|
||||||
if tensor_names and not graph.get_op_nodes(name=tensor_names[0]):
|
'/sink_port_' + str(prev_node_out_port.idx)
|
||||||
result_name = tensor_names[0]
|
node['name'] = result_name
|
||||||
else:
|
continue
|
||||||
|
# If Result name is equal to some tensor name from list, then renaming is not needed
|
||||||
|
if node.soft_get('name') in tensor_names:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try to find tensor name, that is not intersects with graph node names
|
||||||
|
result_name = None
|
||||||
|
for tensor_name in tensor_names:
|
||||||
|
if tensor_name not in names_set:
|
||||||
|
result_name = tensor_name
|
||||||
|
break
|
||||||
|
|
||||||
|
# If we didn't find appropriate tensor name, then Result is named by default naming
|
||||||
|
if result_name is None:
|
||||||
result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
|
result_name = prev_node_out_port.node.soft_get('name', prev_node_out_port.node.id) + \
|
||||||
'/sink_port_' + str(prev_node_out_port.idx)
|
'/sink_port_' + str(prev_node_out_port.idx)
|
||||||
node['name'] = result_name
|
node['name'] = result_name
|
||||||
|
@ -21,21 +21,21 @@ class PowerToEltwises(FrontReplacementOp):
|
|||||||
const = Const(graph, {'value': mo_array(op.scale)}).create_node()
|
const = Const(graph, {'value': mo_array(op.scale)}).create_node()
|
||||||
mul = Mul(graph, {'name': op.name + '/mul_'}).create_node()
|
mul = Mul(graph, {'name': op.name + '/mul_'}).create_node()
|
||||||
const.out_port(0).connect(mul.in_port(1))
|
const.out_port(0).connect(mul.in_port(1))
|
||||||
out_port.connect(mul.in_port(0))
|
mul.in_port(0).get_connection().set_source(out_port)
|
||||||
out_port = mul.out_port(0)
|
out_port = mul.out_port(0)
|
||||||
|
|
||||||
if op.soft_get('shift', 0) != 0:
|
if op.soft_get('shift', 0) != 0:
|
||||||
const = Const(graph, {'value': mo_array(op.shift)}).create_node()
|
const = Const(graph, {'value': mo_array(op.shift)}).create_node()
|
||||||
add = Add(graph, {'name': op.name + '/add_'}).create_node()
|
add = Add(graph, {'name': op.name + '/add_'}).create_node()
|
||||||
const.out_port(0).connect(add.in_port(1))
|
const.out_port(0).connect(add.in_port(1))
|
||||||
out_port.connect(add.in_port(0))
|
add.in_port(0).get_connection().set_source(out_port)
|
||||||
out_port = add.out_port(0)
|
out_port = add.out_port(0)
|
||||||
|
|
||||||
if op.soft_get('power', 1) != 1:
|
if op.soft_get('power', 1) != 1:
|
||||||
const = Const(graph, {'value': mo_array(op.power)}).create_node()
|
const = Const(graph, {'value': mo_array(op.power)}).create_node()
|
||||||
pow = Pow(graph, {'name': op.name + '/pow_'}).create_node()
|
pow = Pow(graph, {'name': op.name + '/pow_'}).create_node()
|
||||||
const.out_port(0).connect(pow.in_port(1))
|
const.out_port(0).connect(pow.in_port(1))
|
||||||
out_port.connect(pow.in_port(0))
|
pow.in_port(0).get_connection().set_source(out_port)
|
||||||
out_port = pow.out_port(0)
|
out_port = pow.out_port(0)
|
||||||
|
|
||||||
op.out_port(0).get_connection().set_source(out_port)
|
op.out_port(0).get_connection().set_source(out_port)
|
||||||
|
@ -780,7 +780,7 @@ def add_output_ops(graph: Graph, user_defined_outputs: dict, inputs: dict = None
|
|||||||
refer_to_faq_msg(29), value['in'], node)
|
refer_to_faq_msg(29), value['in'], node)
|
||||||
for u, v, attrs in in_edges:
|
for u, v, attrs in in_edges:
|
||||||
if 'in' in attrs and attrs['in'] == value['in']:
|
if 'in' in attrs and attrs['in'] == value['in']:
|
||||||
sinks.append(add_opoutput(graph, u, attrs['out']))
|
sinks.append(add_opoutput(graph, u, attrs['out'], user_defined_name=node))
|
||||||
elif 'out' in value:
|
elif 'out' in value:
|
||||||
out_edges = list(graph.out_edges(node, data=True))
|
out_edges = list(graph.out_edges(node, data=True))
|
||||||
if len(out_edges) - 1 < value['out']:
|
if len(out_edges) - 1 < value['out']:
|
||||||
@ -788,9 +788,9 @@ def add_output_ops(graph: Graph, user_defined_outputs: dict, inputs: dict = None
|
|||||||
refer_to_faq_msg(29), value['out'], node)
|
refer_to_faq_msg(29), value['out'], node)
|
||||||
for u, v, attrs in out_edges:
|
for u, v, attrs in out_edges:
|
||||||
if 'out' in attrs and attrs['out'] == value['out']:
|
if 'out' in attrs and attrs['out'] == value['out']:
|
||||||
sinks.append(add_opoutput(graph, node, attrs['out']))
|
sinks.append(add_opoutput(graph, node, attrs['out'], user_defined_name=node))
|
||||||
else:
|
else:
|
||||||
sinks.append(add_opoutput(graph, node, 0))
|
sinks.append(add_opoutput(graph, node, 0, user_defined_name=node))
|
||||||
return sinks
|
return sinks
|
||||||
|
|
||||||
|
|
||||||
@ -870,13 +870,16 @@ def add_input_op_input_port_with_data(graph: Graph, node_id: str, input_op, edge
|
|||||||
out_port.data.set_shape(input_node.soft_get('shape', None))
|
out_port.data.set_shape(input_node.soft_get('shape', None))
|
||||||
input_data_node = input_node.out_node(0)
|
input_data_node = input_node.out_node(0)
|
||||||
|
|
||||||
|
if 'fw_tensor_debug_info' in edge_attrs:
|
||||||
|
input_data_node['fw_tensor_debug_info'] = edge_attrs['fw_tensor_debug_info']
|
||||||
|
|
||||||
log.debug('Input: {} for node {}'.format(input_node.id, node_id))
|
log.debug('Input: {} for node {}'.format(input_node.id, node_id))
|
||||||
log.debug("Add edge from {} to {}".format(input_node.id, input_data_node.id))
|
log.debug("Add edge from {} to {}".format(input_node.id, input_data_node.id))
|
||||||
log.debug("Add edge from {} to {}".format(input_data_node.id, node_id))
|
log.debug("Add edge from {} to {}".format(input_data_node.id, node_id))
|
||||||
return input_node.id
|
return input_node.id
|
||||||
|
|
||||||
|
|
||||||
def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op, port: int):
|
def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op, port: int, fw_info: list):
|
||||||
input_node = input_op.create_node()
|
input_node = input_op.create_node()
|
||||||
# In this case it can be more than one out edge from one port and we should iterate over all output edges
|
# In this case it can be more than one out edge from one port and we should iterate over all output edges
|
||||||
for _, out_node, attrs in graph.out_edges(node_id, data=True):
|
for _, out_node, attrs in graph.out_edges(node_id, data=True):
|
||||||
@ -884,17 +887,20 @@ def add_input_op_output_port_without_data(graph: Graph, node_id: str, input_op,
|
|||||||
# new out port = 0
|
# new out port = 0
|
||||||
attrs = attrs.copy()
|
attrs = attrs.copy()
|
||||||
attrs['out'] = 0
|
attrs['out'] = 0
|
||||||
|
attrs['fw_tensor_debug_info'] = fw_info
|
||||||
|
attrs['data_attrs'] = ['fw_tensor_debug_info']
|
||||||
graph.add_edge(input_node.id, out_node, **attrs)
|
graph.add_edge(input_node.id, out_node, **attrs)
|
||||||
log.debug('Input: {} for node {} output port {}'.format(input_node.id, node_id, port))
|
log.debug('Input: {} for node {} output port {}'.format(input_node.id, node_id, port))
|
||||||
log.debug("Add edge from {} to {}".format(input_node.id, out_node))
|
log.debug("Add edge from {} to {}".format(input_node.id, out_node))
|
||||||
return input_node.id
|
return input_node.id
|
||||||
|
|
||||||
|
|
||||||
def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int):
|
def add_input_op_output_port_with_data(graph: Graph, node_id: str, input_op, port: int, fw_info: list):
|
||||||
# we assume that after op always data node
|
# we assume that after op always data node
|
||||||
assert graph.stage == 'middle', 'add_input_op_input_port_with_data() function can be used only for graph after ' \
|
assert graph.stage == 'middle', 'add_input_op_input_port_with_data() function can be used only for graph after ' \
|
||||||
'shape inference!'
|
'shape inference!'
|
||||||
data_node = Node(graph, node_id).out_node(port)
|
data_node = Node(graph, node_id).out_node(port)
|
||||||
|
data_node['fw_tensor_debug_info'] = fw_info
|
||||||
assert data_node.has_valid('kind') and data_node.kind == 'data'
|
assert data_node.has_valid('kind') and data_node.kind == 'data'
|
||||||
input_node = input_op.create_node()
|
input_node = input_op.create_node()
|
||||||
Node(graph, node_id).out_port(port).get_connection().set_source(input_node.out_port(0))
|
Node(graph, node_id).out_port(port).get_connection().set_source(input_node.out_port(0))
|
||||||
@ -924,21 +930,34 @@ def add_input_op(graph: Graph, node_id: str, port: int = 0, data: bool = False,
|
|||||||
input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id,
|
input_op = Parameter(graph, dict(shape=shape, user_shape=user_shape, data_type=data_type, initial_node_name=node_id,
|
||||||
name=get_new_placeholder_name(node_id, is_out_port, port)))
|
name=get_new_placeholder_name(node_id, is_out_port, port)))
|
||||||
|
|
||||||
fw_name = Node(graph, node_id).soft_get('name')
|
tensor_name = Node(graph, node_id).soft_get('name') + ":" + str(port)
|
||||||
|
fw_info = [(Node(graph, node_id).soft_get('name'), tensor_name)]
|
||||||
|
if is_out_port and port == 0:
|
||||||
|
tensor_name_no_port = Node(graph, node_id).soft_get('name')
|
||||||
|
# TODO: This can be optimized. Tensor names can be stored as set, which is initialized after model loading.
|
||||||
|
graph_tensor_names = graph.get_tensor_names_set()
|
||||||
|
if tensor_name_no_port in graph_tensor_names:
|
||||||
|
log.warning('Could not add user defined input name {} to tensor names list of as '
|
||||||
|
'graph contains tensor name with same name.'.format(tensor_name_no_port))
|
||||||
|
else:
|
||||||
|
# Add alias with operation name, as this format is used in some config files
|
||||||
|
fw_info.append((Node(graph, node_id).soft_get('name'), tensor_name_no_port))
|
||||||
|
|
||||||
edge_attrs = {'in': port, 'out': 0, 'in_attrs': ['in'], 'out_attrs': ['out'],
|
edge_attrs = {'in': port, 'out': 0, 'in_attrs': ['in'], 'out_attrs': ['out'],
|
||||||
'fw_tensor_debug_info': [(fw_name, fw_name)],
|
'fw_tensor_debug_info': fw_info,
|
||||||
'data_attrs': ['fw_tensor_debug_info']}
|
'data_attrs': ['fw_tensor_debug_info']}
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
if is_out_port:
|
if is_out_port:
|
||||||
new_input_id = add_input_op_output_port_without_data(graph=graph, node_id=node_id, input_op=input_op,
|
new_input_id = add_input_op_output_port_without_data(graph=graph, node_id=node_id, input_op=input_op,
|
||||||
port=port)
|
port=port, fw_info=edge_attrs['fw_tensor_debug_info'])
|
||||||
else:
|
else:
|
||||||
new_input_id = add_input_op_input_port_without_data(graph=graph, node_id=node_id, input_op=input_op,
|
new_input_id = add_input_op_input_port_without_data(graph=graph, node_id=node_id, input_op=input_op,
|
||||||
edge_attrs=edge_attrs)
|
edge_attrs=edge_attrs)
|
||||||
else:
|
else:
|
||||||
if is_out_port:
|
if is_out_port:
|
||||||
new_input_id = add_input_op_output_port_with_data(graph=graph, node_id=node_id, input_op=input_op,
|
new_input_id = add_input_op_output_port_with_data(graph=graph, node_id=node_id, input_op=input_op,
|
||||||
port=port)
|
port=port, fw_info=edge_attrs['fw_tensor_debug_info'])
|
||||||
else:
|
else:
|
||||||
new_input_id = add_input_op_input_port_with_data(graph=graph, node_id=node_id, input_op=input_op,
|
new_input_id = add_input_op_input_port_with_data(graph=graph, node_id=node_id, input_op=input_op,
|
||||||
edge_attrs=edge_attrs)
|
edge_attrs=edge_attrs)
|
||||||
@ -1064,6 +1083,16 @@ def add_input_ops(graph: Graph, user_defined_inputs: dict, before_infer: bool):
|
|||||||
smart_node['data_type'] = data_type
|
smart_node['data_type'] = data_type
|
||||||
inputs.append(node_id)
|
inputs.append(node_id)
|
||||||
port_and_shape_info['added'] = True
|
port_and_shape_info['added'] = True
|
||||||
|
|
||||||
|
if smart_node.out_edges():
|
||||||
|
# User specified input is Parameter, so input cut is not needed, but
|
||||||
|
# Op name needs to be added to tensor names
|
||||||
|
fw_info = []
|
||||||
|
op_name = smart_node.soft_get('name')
|
||||||
|
if 'fw_tensor_debug_info' in smart_node.out_edge(0):
|
||||||
|
fw_info += smart_node.out_edge(0)['fw_tensor_debug_info']
|
||||||
|
smart_node.out_edge(0)['fw_tensor_debug_info'] = fw_info + [(op_name, op_name)]
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if before_infer:
|
if before_infer:
|
||||||
|
@ -54,8 +54,9 @@ def align_frame_time(graph: Graph, node: Node, frame_time_max):
|
|||||||
# add element_size for MemoryOffset after Parameter for infer
|
# add element_size for MemoryOffset after Parameter for infer
|
||||||
if in_node.op == 'Parameter':
|
if in_node.op == 'Parameter':
|
||||||
memory_align['element_size'] = in_node.shape
|
memory_align['element_size'] = in_node.shape
|
||||||
|
|
||||||
|
memory_align.in_port(0).get_connection().set_source(in_node_out_port)
|
||||||
in_port.get_connection().set_source(memory_align.out_port(0))
|
in_port.get_connection().set_source(memory_align.out_port(0))
|
||||||
memory_align.in_port(0).connect(in_node_out_port)
|
|
||||||
memory_align['frame_time'] = memory_align.t
|
memory_align['frame_time'] = memory_align.t
|
||||||
# remove MemoryOffset with maximum delay
|
# remove MemoryOffset with maximum delay
|
||||||
elif in_node.frame_time == frame_time_max and in_node.op == 'MemoryOffset':
|
elif in_node.frame_time == frame_time_max and in_node.op == 'MemoryOffset':
|
||||||
|
@ -62,7 +62,7 @@ class Connection:
|
|||||||
def get_destinations(self):
|
def get_destinations(self):
|
||||||
return self.destinations
|
return self.destinations
|
||||||
|
|
||||||
def set_source(self, port, attributes_save_mode: str = "merge"):
|
def set_source(self, port, attributes_save_mode=None):
|
||||||
# In this method we are changing source for a connection with given port.
|
# In this method we are changing source for a connection with given port.
|
||||||
# See detailed example below.
|
# See detailed example below.
|
||||||
#
|
#
|
||||||
@ -99,6 +99,16 @@ class Connection:
|
|||||||
if self.control_flow is True:
|
if self.control_flow is True:
|
||||||
raise Error("Cannot operate with connection with control_flow=True")
|
raise Error("Cannot operate with connection with control_flow=True")
|
||||||
|
|
||||||
|
if attributes_save_mode is None:
|
||||||
|
attributes_save_mode = "merge"
|
||||||
|
if self.source is not None:
|
||||||
|
scr_node = self.source.node
|
||||||
|
|
||||||
|
# Force "source" mode for "Parameter" source node, which preserves tensor names for
|
||||||
|
# source node in connection.
|
||||||
|
if scr_node.soft_get("type") == "Parameter":
|
||||||
|
attributes_save_mode = "source"
|
||||||
|
|
||||||
if self.graph.stage == 'front':
|
if self.graph.stage == 'front':
|
||||||
scr_node = port.node
|
scr_node = port.node
|
||||||
|
|
||||||
@ -161,7 +171,7 @@ class Connection:
|
|||||||
else:
|
else:
|
||||||
self.graph.add_edge(port_out_data.id, dst_port.node.id, **{'in': dst_port.idx})
|
self.graph.add_edge(port_out_data.id, dst_port.node.id, **{'in': dst_port.idx})
|
||||||
|
|
||||||
def set_destination(self, port, attributes_save_mode: str = "merge"):
|
def set_destination(self, port, attributes_save_mode=None):
|
||||||
# In this method we are changing destination for a connection with given port with type 'in'.
|
# In this method we are changing destination for a connection with given port with type 'in'.
|
||||||
# This method requires exactly one destination or empty destinations list.
|
# This method requires exactly one destination or empty destinations list.
|
||||||
# See detailed example below.
|
# See detailed example below.
|
||||||
@ -212,6 +222,16 @@ class Connection:
|
|||||||
if self.control_flow is True:
|
if self.control_flow is True:
|
||||||
raise Error("Cannot operate with connection with control_flow=True")
|
raise Error("Cannot operate with connection with control_flow=True")
|
||||||
|
|
||||||
|
if attributes_save_mode is None:
|
||||||
|
attributes_save_mode = "merge"
|
||||||
|
if self.source is not None:
|
||||||
|
scr_node = self.source.node
|
||||||
|
|
||||||
|
# Force "source" mode for "Parameter" source node, which preserves tensor names for
|
||||||
|
# source node in connection.
|
||||||
|
if scr_node.soft_get("type") == "Parameter":
|
||||||
|
attributes_save_mode = "source"
|
||||||
|
|
||||||
if self.graph.stage == 'front':
|
if self.graph.stage == 'front':
|
||||||
if self.source is not None:
|
if self.source is not None:
|
||||||
node = self.source.node
|
node = self.source.node
|
||||||
|
@ -1003,6 +1003,24 @@ class Graph(nx.MultiDiGraph):
|
|||||||
# Add Const op for constant data nodes
|
# Add Const op for constant data nodes
|
||||||
add_constant_operations(self)
|
add_constant_operations(self)
|
||||||
|
|
||||||
|
def get_tensor_names_set(self, use_ports = False):
|
||||||
|
"""
|
||||||
|
Get set of tensor names of the graph.
|
||||||
|
"""
|
||||||
|
tensor_names_set = set()
|
||||||
|
for node in self.get_op_nodes():
|
||||||
|
if self.stage is None:
|
||||||
|
for out_edge_idx in node.out_edges():
|
||||||
|
out_edge = node.out_edge(out_edge_idx)
|
||||||
|
if "fw_tensor_debug_info" in out_edge:
|
||||||
|
for _, tensor_name in out_edge["fw_tensor_debug_info"]:
|
||||||
|
tensor_names_set.add(tensor_name)
|
||||||
|
else:
|
||||||
|
for _, port in node.out_ports().items():
|
||||||
|
tensor_names = port.get_tensor_names()
|
||||||
|
tensor_names_set = tensor_names_set.union(set(tensor_names))
|
||||||
|
return tensor_names_set
|
||||||
|
|
||||||
def topological_sort(self, reverse: bool = False):
|
def topological_sort(self, reverse: bool = False):
|
||||||
sorted_node_ids = nx.topological_sort(self)
|
sorted_node_ids = nx.topological_sort(self)
|
||||||
|
|
||||||
@ -1051,7 +1069,8 @@ def dict_includes(big: dict, sub_dict: dict, skip_attr_names=[]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep_output_port: bool = False):
|
def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep_output_port: bool = False,
|
||||||
|
user_defined_name=None):
|
||||||
"""
|
"""
|
||||||
Creates and connects Result node to node_name port. Cuts existing port if requested.
|
Creates and connects Result node to node_name port. Cuts existing port if requested.
|
||||||
:param graph: graph to operate with
|
:param graph: graph to operate with
|
||||||
@ -1059,6 +1078,7 @@ def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep
|
|||||||
:param port: output port of node to connect Result to
|
:param port: output port of node to connect Result to
|
||||||
:param cut: determines way of operating with edge specified by node_name and port
|
:param cut: determines way of operating with edge specified by node_name and port
|
||||||
:param keep_output_port: special attribute determines if this operation is saved in IR or not
|
:param keep_output_port: special attribute determines if this operation is saved in IR or not
|
||||||
|
:param user_defined_name: User defined operation name, which should be added to tensor names list
|
||||||
"""
|
"""
|
||||||
# we import it here because Op imports add_attrs_props and update_ie_fields from this file
|
# we import it here because Op imports add_attrs_props and update_ie_fields from this file
|
||||||
from openvino.tools.mo.ops.result import Result
|
from openvino.tools.mo.ops.result import Result
|
||||||
@ -1071,6 +1091,26 @@ def add_opoutput(graph: Graph, node_name: str, port: int, cut: bool = True, keep
|
|||||||
'keep_output_port': keep_output_port})
|
'keep_output_port': keep_output_port})
|
||||||
opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
|
opoutput_node.in_edge()['data_attrs'] = ['fw_tensor_debug_info']
|
||||||
|
|
||||||
|
if user_defined_name is not None and (graph.stage == 'front' or graph.stage is None):
|
||||||
|
# Following code adds user_defined_name to tensor names list
|
||||||
|
# Not applicable for middle stage
|
||||||
|
prev_op_tensor_names = set()
|
||||||
|
in_edge_attrs = opoutput_node.in_edge()
|
||||||
|
if 'fw_tensor_debug_info' in in_edge_attrs:
|
||||||
|
for _, tensor_name in opoutput_node.in_edge()['fw_tensor_debug_info']:
|
||||||
|
prev_op_tensor_names.add(tensor_name)
|
||||||
|
if user_defined_name not in prev_op_tensor_names:
|
||||||
|
# TODO: This can be optimized. Tensor names can be stored as set, which is initialized after model loading.
|
||||||
|
graph_tensor_names = graph.get_tensor_names_set()
|
||||||
|
if user_defined_name in graph_tensor_names:
|
||||||
|
log.warning('Could not add user defined output name {} to tensor names list of {} node as '
|
||||||
|
'graph contains tensor name with same name.'.format(user_defined_name,
|
||||||
|
opoutput_node.soft_get('name')))
|
||||||
|
else:
|
||||||
|
if 'fw_tensor_debug_info' not in in_edge_attrs:
|
||||||
|
in_edge_attrs['fw_tensor_debug_info'] = []
|
||||||
|
in_edge_attrs['fw_tensor_debug_info'].append([user_defined_name, user_defined_name])
|
||||||
|
|
||||||
log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))
|
log.debug('Sink: {} for node {}'.format(opoutput_node.id, node_name))
|
||||||
log.debug(str(graph.node[opoutput_node.id]))
|
log.debug(str(graph.node[opoutput_node.id]))
|
||||||
log.debug("Add edge from {} to {}".format(node_name, opoutput_node.id))
|
log.debug("Add edge from {} to {}".format(node_name, opoutput_node.id))
|
||||||
|
@ -274,6 +274,29 @@ class Port:
|
|||||||
tensor_names_list.append(tensor_name.replace(',', '\\,'))
|
tensor_names_list.append(tensor_name.replace(',', '\\,'))
|
||||||
return sorted(tensor_names_list)
|
return sorted(tensor_names_list)
|
||||||
|
|
||||||
|
def add_tensor_names(self, op_names: list, tensor_names: list, port_renumber: bool = False):
|
||||||
|
"""
|
||||||
|
Sets tensor names list.
|
||||||
|
:param op_names: list of op names.
|
||||||
|
:param tensor_names: list of lists of tensor names.
|
||||||
|
:param port_renumber: defines whether data node index should be calculated considering port renumbering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert len(tensor_names) == len(op_names), \
|
||||||
|
"Number of tensor name lists should be the same as number of operation name."
|
||||||
|
|
||||||
|
if len(tensor_names) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_debug_items = []
|
||||||
|
for tensor_names, op_name in zip(tensor_names, op_names):
|
||||||
|
assert isinstance(tensor_names, list), "Tensor names elements should be lists with strings."
|
||||||
|
new_debug_items.append((op_name, ','.join(tensor_names)))
|
||||||
|
|
||||||
|
tensor_debug_info = self.get_tensor_debug_info(port_renumber)
|
||||||
|
tensor_debug_info += new_debug_items
|
||||||
|
self.set_tensor_debug_info(tensor_debug_info, port_renumber)
|
||||||
|
|
||||||
def get_tensor_debug_info(self, port_renumber: bool = False):
|
def get_tensor_debug_info(self, port_renumber: bool = False):
|
||||||
"""
|
"""
|
||||||
Gets tensor debug info attribute.
|
Gets tensor debug info attribute.
|
||||||
@ -306,6 +329,31 @@ class Port:
|
|||||||
fw_debug_info += get_tensor_debug_info_from_attrs(out_node.attrs())
|
fw_debug_info += get_tensor_debug_info_from_attrs(out_node.attrs())
|
||||||
return fw_debug_info
|
return fw_debug_info
|
||||||
|
|
||||||
|
def set_tensor_debug_info(self, tensor_info: list, port_renumber: bool = False):
|
||||||
|
"""
|
||||||
|
Gets tensor debug info attribute.
|
||||||
|
:param tensor_info: new tensor debug info value.
|
||||||
|
:param port_renumber: defines whether data node index should be calculated considering port renumbering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self.type != 'in', "Can't get tensor debug info for input port at {} node".format(self.node.name)
|
||||||
|
|
||||||
|
if self.node.graph.stage == 'front':
|
||||||
|
if self.idx in self.node.out_edges():
|
||||||
|
out_edge = self.node.out_edge(self.idx)
|
||||||
|
out_edge['fw_tensor_debug_info'] = tensor_info
|
||||||
|
else:
|
||||||
|
# before port renumbering we use sequential numbering
|
||||||
|
node_idx = self.idx
|
||||||
|
if port_renumber:
|
||||||
|
if self.node.type != 'Const':
|
||||||
|
# after port renumbering port indices start from zero,
|
||||||
|
# but data node indices remain the same
|
||||||
|
node_idx = self.idx + len(self.node.in_nodes())
|
||||||
|
|
||||||
|
if node_idx in self.node.out_nodes():
|
||||||
|
out_node = self.node.out_node(node_idx)
|
||||||
|
out_node['fw_tensor_debug_info'] = tensor_info
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
if self.type == 'out':
|
if self.type == 'out':
|
||||||
@ -416,3 +464,12 @@ class Port:
|
|||||||
raise Error('Trying to override data type for output port {} of operation {}: from {} to {}'.format(
|
raise Error('Trying to override data type for output port {} of operation {}: from {} to {}'.format(
|
||||||
self.idx, node.name, node._out_port_data_type[self.idx], data_type))
|
self.idx, node.name, node._out_port_data_type[self.idx], data_type))
|
||||||
node._out_port_data_type[self.idx] = data_type
|
node._out_port_data_type[self.idx] = data_type
|
||||||
|
|
||||||
|
def get_default_tensor_name(self):
|
||||||
|
"""
|
||||||
|
Gets default_tensor_name
|
||||||
|
:return: tensor name
|
||||||
|
"""
|
||||||
|
if self.type == 'in':
|
||||||
|
return None
|
||||||
|
return self.node.soft_get('name', self.node.id) + ":" + str(self.idx)
|
||||||
|
@ -142,6 +142,9 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
|
|||||||
const_port.disconnect()
|
const_port.disconnect()
|
||||||
# as Mul node is added before convolution, output tensor from Convolution node
|
# as Mul node is added before convolution, output tensor from Convolution node
|
||||||
# corresponds to original Mul node
|
# corresponds to original Mul node
|
||||||
|
if producer_port.node.soft_get('type') == 'Parameter':
|
||||||
|
node.out_port(0).get_connection().set_source(producer_port, "source")
|
||||||
|
else:
|
||||||
node.out_port(0).get_connection().set_source(producer_port, "dest")
|
node.out_port(0).get_connection().set_source(producer_port, "dest")
|
||||||
|
|
||||||
return is_fused
|
return is_fused
|
||||||
|
@ -10,14 +10,17 @@ import networkx as nx
|
|||||||
|
|
||||||
from openvino.tools.mo.back.RemoveUselessConvert import RemoveUselessConvert
|
from openvino.tools.mo.back.RemoveUselessConvert import RemoveUselessConvert
|
||||||
from openvino.tools.mo.back.ResultRename import ResultRename
|
from openvino.tools.mo.back.ResultRename import ResultRename
|
||||||
|
from openvino.tools.mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, \
|
||||||
|
serialize_mean_image
|
||||||
from openvino.tools.mo.back.op_versioning import OpVersioning
|
from openvino.tools.mo.back.op_versioning import OpVersioning
|
||||||
from openvino.tools.mo.ops.Cast import Cast
|
|
||||||
from openvino.tools.mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
|
|
||||||
from openvino.tools.mo.graph.graph import Node, Graph
|
from openvino.tools.mo.graph.graph import Node, Graph
|
||||||
from openvino.tools.mo.middle.passes import tensor_names, convert_data_type
|
from openvino.tools.mo.middle.passes import tensor_names, convert_data_type
|
||||||
from openvino.tools.mo.middle.passes.convert_data_type import data_type_str_to_np
|
from openvino.tools.mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||||
|
from openvino.tools.mo.middle.passes.eliminate import shape_inference
|
||||||
from openvino.tools.mo.middle.passes.infer import type_infer
|
from openvino.tools.mo.middle.passes.infer import type_infer
|
||||||
from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
from openvino.tools.mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||||
|
from openvino.tools.mo.ops.Cast import Cast
|
||||||
|
from openvino.tools.mo.ops.op import Op
|
||||||
from openvino.tools.mo.utils.error import Error
|
from openvino.tools.mo.utils.error import Error
|
||||||
|
|
||||||
|
|
||||||
@ -171,6 +174,25 @@ def convert_inputs_of_specific_ops(graph: Graph):
|
|||||||
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
|
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_tensor_names_for_parameters_results(graph: Graph):
|
||||||
|
for node in graph.get_op_nodes():
|
||||||
|
if node.soft_get('type') == 'Result' and node.is_in_port_connected(0):
|
||||||
|
port = node.in_port(0).get_connection().get_source()
|
||||||
|
elif node.soft_get('type') == 'Parameter' and node.is_out_port_connected(0):
|
||||||
|
port = node.out_port(0)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if node.has_and_set('keep_output_port'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensors = port.get_tensor_names()
|
||||||
|
if tensors is not None and isinstance(tensors, list) and len(tensors) > 0:
|
||||||
|
continue
|
||||||
|
new_tensor_name = port.get_default_tensor_name()
|
||||||
|
op_name = port.node.soft_get('name')
|
||||||
|
port.add_tensor_names([op_name], [[new_tensor_name]])
|
||||||
|
|
||||||
|
|
||||||
def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_name: str,
|
def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_name: str,
|
||||||
mean_data: [list, None] = None, input_names: list = None, meta_info: dict = None,
|
mean_data: [list, None] = None, input_names: list = None, meta_info: dict = None,
|
||||||
use_temporary_path=False, convert_types=False):
|
use_temporary_path=False, convert_types=False):
|
||||||
@ -199,6 +221,7 @@ def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_
|
|||||||
for_graph_and_each_sub_graph_recursively(graph, RemoveUselessConvert().find_and_replace_pattern)
|
for_graph_and_each_sub_graph_recursively(graph, RemoveUselessConvert().find_and_replace_pattern)
|
||||||
|
|
||||||
ResultRename().find_and_replace_pattern(graph)
|
ResultRename().find_and_replace_pattern(graph)
|
||||||
|
set_default_tensor_names_for_parameters_results(graph)
|
||||||
|
|
||||||
for sub_graph in [graph] + collect_sub_graphs(graph):
|
for sub_graph in [graph] + collect_sub_graphs(graph):
|
||||||
op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
|
op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
|
||||||
|
@ -32,6 +32,12 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
op1_node.add_output_port(0)
|
op1_node.add_output_port(0)
|
||||||
self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])
|
self.assertTrue(op1_node.out_port(0).get_tensor_names() == [])
|
||||||
|
|
||||||
|
input_node.out_port(0).add_tensor_names(["A", "B", "C"], [["A:0"], ["B:0", "B:1", "B:2"], ["C:0"]])
|
||||||
|
self.assertTrue(input_node.out_port(0).get_tensor_debug_info() ==
|
||||||
|
[('input', 'input'), ('Op1', 'Op1,Op2'), ("A", "A:0"), ("B", "B:0,B:1,B:2"), ("C", "C:0")])
|
||||||
|
self.assertTrue(input_node.out_port(0).get_tensor_names() ==
|
||||||
|
['A:0', 'B:0\\,B:1\\,B:2', 'C:0', 'Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
def test_middle(self):
|
def test_middle(self):
|
||||||
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
|
||||||
('input_data', 'Op2')])
|
('input_data', 'Op2')])
|
||||||
@ -47,6 +53,12 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
op2_node.add_output_port(0)
|
op2_node.add_output_port(0)
|
||||||
self.assertTrue(op2_node.out_port(0).get_tensor_names() == [])
|
self.assertTrue(op2_node.out_port(0).get_tensor_names() == [])
|
||||||
|
|
||||||
|
input_node.out_port(0).add_tensor_names(["A", "B", "C"], [["A:0"], ["B:0", "B:1", "B:2"], ["C:0"]])
|
||||||
|
self.assertTrue(input_node.out_port(0).get_tensor_debug_info() ==
|
||||||
|
[('input', 'input'), ('Op1', 'Op1,Op2'), ("A", "A:0"), ("B", "B:0,B:1,B:2"), ("C", "C:0")])
|
||||||
|
self.assertTrue(input_node.out_port(0).get_tensor_names() ==
|
||||||
|
['A:0', 'B:0\\,B:1\\,B:2', 'C:0', 'Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
def test_port_renumber(self):
|
def test_port_renumber(self):
|
||||||
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'),
|
||||||
('Op1', 'Op1_data', {'out': 1}), ('Op1_data', 'Op2')])
|
('Op1', 'Op1_data', {'out': 1}), ('Op1_data', 'Op2')])
|
||||||
@ -58,6 +70,13 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == ['Op1\\,Op2'])
|
self.assertTrue(op1_node.out_port(0).get_tensor_names(port_renumber=True) == ['Op1\\,Op2'])
|
||||||
|
|
||||||
|
input_node.out_port(0).add_tensor_names(["A", "B", "C"], [["A:0"], ["B:0", "B:1", "B:2"], ["C:0"]],
|
||||||
|
port_renumber=True)
|
||||||
|
self.assertTrue(input_node.out_port(0).get_tensor_debug_info(port_renumber=True) ==
|
||||||
|
[('input', 'input'), ('Op1', 'Op1,Op2'), ("A", "A:0"), ("B", "B:0,B:1,B:2"), ("C", "C:0")])
|
||||||
|
self.assertTrue(input_node.out_port(0).get_tensor_names(port_renumber=True) ==
|
||||||
|
['A:0', 'B:0\\,B:1\\,B:2', 'C:0', 'Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
def test_reconnect_middle_case1(self):
|
def test_reconnect_middle_case1(self):
|
||||||
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
|
||||||
input_node = Node(graph, 'input')
|
input_node = Node(graph, 'input')
|
||||||
@ -65,11 +84,24 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
input_node_out_port = input_node.out_port(0)
|
input_node_out_port = input_node.out_port(0)
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
|
op3_node = Node(graph, 'Op3')
|
||||||
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
|
||||||
|
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() is None)
|
||||||
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
||||||
|
|
||||||
|
def test_reconnect_middle_case1_parameter(self):
|
||||||
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
|
||||||
|
input_node = Node(graph, 'input')
|
||||||
|
|
||||||
|
input_node_out_port = input_node.out_port(0)
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
op3_node = Node(graph, 'Op3')
|
op3_node = Node(graph, 'Op3')
|
||||||
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
||||||
|
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() is None)
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
|
||||||
|
|
||||||
def test_reconnect_front_case1(self):
|
def test_reconnect_front_case1(self):
|
||||||
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
|
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
|
||||||
@ -81,11 +113,27 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
input_node_out_port = input_node.out_port(0)
|
input_node_out_port = input_node.out_port(0)
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
|
op3_node = Node(graph, 'Op3')
|
||||||
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
|
||||||
|
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
||||||
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
||||||
|
|
||||||
|
def test_reconnect_front_case1_parameter(self):
|
||||||
|
graph = build_graph(nodes, [('input', 'Op1', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('input', 'input'),
|
||||||
|
('Op1', 'Op1,Op2')]}),
|
||||||
|
('Op3', 'Op2', {'in': 0, 'out': 0, 'fw_tensor_debug_info': [('Op3', 'Op3')]})])
|
||||||
|
graph.stage = 'front'
|
||||||
|
input_node = Node(graph, 'input')
|
||||||
|
|
||||||
|
input_node_out_port = input_node.out_port(0)
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
op3_node = Node(graph, 'Op3')
|
op3_node = Node(graph, 'Op3')
|
||||||
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
||||||
|
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
||||||
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
|
||||||
|
|
||||||
def test_reconnect_middle_case1(self):
|
def test_reconnect_middle_case1(self):
|
||||||
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
|
||||||
@ -94,9 +142,36 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
input_node_out_port = input_node.out_port(0)
|
input_node_out_port = input_node.out_port(0)
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
|
op3_node = Node(graph, 'Op3')
|
||||||
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
|
||||||
|
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
||||||
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
||||||
|
|
||||||
|
def test_reconnect_middle_case1_parameter(self):
|
||||||
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1'), ('Op3', 'Op3_data')])
|
||||||
|
input_node = Node(graph, 'input')
|
||||||
|
|
||||||
|
input_node_out_port = input_node.out_port(0)
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
op3_node = Node(graph, 'Op3')
|
op3_node = Node(graph, 'Op3')
|
||||||
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
||||||
|
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
|
||||||
|
|
||||||
|
def test_reconnect_middle_case2(self):
|
||||||
|
graph = build_graph(nodes, [('input', 'input_data'), ('input_data', 'Op1', {'out': 0}),
|
||||||
|
('input_data', 'Op1', {'out': 1}), ('Op3', 'Op3_data')])
|
||||||
|
input_node = Node(graph, 'input')
|
||||||
|
|
||||||
|
input_node_out_port = input_node.out_port(0)
|
||||||
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
|
|
||||||
|
op3_node = Node(graph, 'Op3')
|
||||||
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0), "merge")
|
||||||
|
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
||||||
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
||||||
|
|
||||||
@ -111,8 +186,8 @@ class TestsGetTensorNames(unittest.TestCase):
|
|||||||
op3_node = Node(graph, 'Op3')
|
op3_node = Node(graph, 'Op3')
|
||||||
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
input_node_out_port.get_connection().set_source(op3_node.out_port(0))
|
||||||
|
|
||||||
self.assertTrue(input_node_out_port.get_tensor_names() == [])
|
self.assertTrue(input_node_out_port.get_tensor_names() == ['Op1\\,Op2', 'input'])
|
||||||
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op1\\,Op2', 'Op3', 'input'])
|
self.assertTrue(op3_node.out_port(0).get_tensor_names() == ['Op3'])
|
||||||
|
|
||||||
|
|
||||||
class TestPortMethods(unittest.TestCase):
|
class TestPortMethods(unittest.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user