Change sort functions to fix incorrect order in port dictionaries (#7283)
* Change sort functions to fix incorrect order * Add separate sorts for different cases * Update sort function to correct control flow edges handling * Refactor all sorts for dictionary ordering * Update build_graph function for correct control_flow edges creation * Add tests for nodes and ports sort equality * Fix wrong port value * Add direct conversion from bool to int * Use .replace instead of .strip * Add test with both control_flow and ordinary edges, refactored edges lists filling * Refactored nodes dict filling * Delete unused test code
This commit is contained in:
committed by
GitHub
parent
168d8b9e84
commit
9cfdad9afc
@@ -874,7 +874,7 @@ class ObjectDetectionAPIPreprocessor2Replacement(FrontReplacementFromConfigFileG
|
||||
# replace sub-graph between start and end nodes (including them) with new_preprocessing_ops nodes
|
||||
end_node.out_port(0).get_connection().set_source(new_preprocessing_ops[-1].out_port(0))
|
||||
start_node.in_port(0).get_connection().set_destination(
|
||||
new_preprocessing_ops[0].in_port(new_preprocessing_ops[0].is_in_port_connected(0)))
|
||||
new_preprocessing_ops[0].in_port(int(new_preprocessing_ops[0].is_in_port_connected(0))))
|
||||
else:
|
||||
if trailing: # case 2
|
||||
# change output of the end_node to be produced with the start node producer
|
||||
|
||||
@@ -147,7 +147,7 @@ class Node:
|
||||
for idx in self._in_ports:
|
||||
if control_flow or 'control_flow' not in self._in_ports[idx] or not self._in_ports[idx]['control_flow']:
|
||||
ports.update({idx: self.in_port(idx, control_flow=control_flow)})
|
||||
return dict_to_ordered_dict(ports, func=lambda t: str(t))
|
||||
return dict_to_ordered_dict(ports, func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
|
||||
def out_port(self, idx=None, control_flow=False) -> Port:
|
||||
if not self.has_valid('_out_ports'):
|
||||
@@ -165,7 +165,7 @@ class Node:
|
||||
for idx in self._out_ports:
|
||||
if control_flow or 'control_flow' not in self._out_ports[idx] or not self._out_ports[idx]['control_flow']:
|
||||
ports.update({idx: self.out_port(idx, control_flow=control_flow)})
|
||||
return dict_to_ordered_dict(ports, func=lambda t: str(t))
|
||||
return dict_to_ordered_dict(ports, func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
|
||||
def has_port(self, port_type, idx, control_flow=False):
|
||||
assert port_type in ['in', 'out'], "Invalid usage of has_port method"
|
||||
@@ -195,12 +195,14 @@ class Node:
|
||||
|
||||
def in_nodes_edges(self, control_flow: bool = False):
|
||||
return dict_to_ordered_dict({x[1]['in']: (Node(self.graph, x[0]), x[1]) for x in
|
||||
self.get_inputs(control_flow=control_flow)})
|
||||
self.get_inputs(control_flow=control_flow)},
|
||||
func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
|
||||
def in_nodes(self, control_flow: bool = False):
|
||||
if self.kind == 'op':
|
||||
return dict_to_ordered_dict({x[1]['in']: Node(self.graph, x[0]) for x in
|
||||
self.get_inputs(control_flow=control_flow)})
|
||||
self.get_inputs(control_flow=control_flow)},
|
||||
func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
elif self.kind == 'data':
|
||||
return [Node(self.graph, n) for n, d in self.get_inputs(control_flow=control_flow)]
|
||||
|
||||
@@ -211,20 +213,23 @@ class Node:
|
||||
assert self.has('kind')
|
||||
assert self.kind in ['op', 'data']
|
||||
if self.kind == 'op':
|
||||
return dict_to_ordered_dict({x[1]['in']: x[1] for x in self.get_inputs(control_flow=control_flow)})
|
||||
return dict_to_ordered_dict({x[1]['in']: x[1] for x in self.get_inputs(control_flow=control_flow)},
|
||||
func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
elif self.kind == 'data':
|
||||
return [d for n, d in self.get_inputs(control_flow=control_flow)]
|
||||
|
||||
def out_nodes_edges(self, control_flow: bool = False):
|
||||
return dict_to_ordered_dict({x[1]['out']: (Node(self.graph, x[0]), x[1]) for x in
|
||||
self.get_outputs(control_flow=control_flow)})
|
||||
self.get_outputs(control_flow=control_flow)},
|
||||
func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
|
||||
def out_nodes(self, control_flow: bool = False):
|
||||
assert self.has('kind')
|
||||
assert self.kind in ['op', 'data']
|
||||
if self.kind == 'op':
|
||||
return dict_to_ordered_dict({x[1]['out']: Node(self.graph, x[0]) for x in
|
||||
self.get_outputs(control_flow=control_flow)})
|
||||
self.get_outputs(control_flow=control_flow)},
|
||||
func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
elif self.kind == 'data':
|
||||
return [Node(self.graph, n) for n, d in self.get_outputs(control_flow=control_flow)]
|
||||
|
||||
@@ -232,7 +237,8 @@ class Node:
|
||||
assert self.has('kind')
|
||||
assert self.kind in ['op', 'data']
|
||||
if self.kind == 'op':
|
||||
return dict_to_ordered_dict({x[1]['out']: x[1] for x in self.get_outputs(control_flow=control_flow)})
|
||||
return dict_to_ordered_dict({x[1]['out']: x[1] for x in self.get_outputs(control_flow=control_flow)},
|
||||
func=lambda t: int(str(t).replace('control_flow_', '')))
|
||||
elif self.kind == 'data':
|
||||
return [d for n, d in self.get_outputs(control_flow=control_flow)]
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from mo.graph.graph import Node, Graph, add_opoutput, dict_includes_compare_attr
|
||||
from mo.ops.const import Const
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph
|
||||
from unit_tests.utils.graph import build_graph, build_graph_with_edge_attrs
|
||||
|
||||
nodes = {
|
||||
'0': {'name': 'input1', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
|
||||
@@ -429,6 +429,20 @@ class TestNewGraphAPIMiddle(unittest.TestCase):
|
||||
'const_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
}
|
||||
|
||||
nodes_10_in_10_out = {
|
||||
'op_concat': {'type': 'Concat', 'value': None, 'kind': 'op', 'op': 'Concat'},
|
||||
'op_concat_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'op_split': {'type': 'Split', 'value': None, 'kind': 'op', 'op': 'Split'},
|
||||
}
|
||||
|
||||
# Filling nodes list
|
||||
for idx in range(11):
|
||||
nodes_10_in_10_out.update({'in_{}'.format(idx): {'type': 'Parameter', 'value': None, 'kind': 'op', 'op': 'Parameter'}})
|
||||
nodes_10_in_10_out.update({'in_{}_data'.format(idx): {'value': None, 'shape': None, 'kind': 'data'}})
|
||||
nodes_10_in_10_out.update({'out_{}'.format(idx): {'type': 'Parameter', 'value': None, 'kind': 'op', 'op': 'Parameter'}})
|
||||
nodes_10_in_10_out.update({'op_split_{}_data'.format(idx): {'value': None, 'shape': None, 'kind': 'data'}})
|
||||
|
||||
###########################################
|
||||
###### TESTS FOR PORT CLASS METHODS #######
|
||||
###########################################
|
||||
@@ -1083,6 +1097,113 @@ class TestNewGraphAPIMiddle(unittest.TestCase):
|
||||
for idx in range(len(node.out_ports())):
|
||||
self.assertEqual(node.out_port(idx), node.out_ports()[idx])
|
||||
|
||||
def test_node_in_ports_order_10_inputs(self):
|
||||
edges = [('op_concat', 'op_concat_data'),
|
||||
('op_concat_data', 'op_split'),
|
||||
]
|
||||
|
||||
# Filling edges list
|
||||
for idx in range(11):
|
||||
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx)))
|
||||
edges.append(('in_{}_data'.format(idx), 'op_concat', {'in': idx}))
|
||||
edges.append(('op_split', 'op_split_{}_data'.format(idx), {'out': idx}))
|
||||
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx)))
|
||||
|
||||
graph = build_graph(self.nodes_10_in_10_out, edges)
|
||||
|
||||
node_concat = Node(graph, 'op_concat')
|
||||
node_split = Node(graph, 'op_split')
|
||||
|
||||
self.assertEqual(len(node_concat.in_ports()), len(node_concat.in_nodes()))
|
||||
|
||||
l1 = [node_concat.in_port(idx).get_source().node.name for idx in node_concat.in_ports()]
|
||||
l2 = [node_concat.in_node(idx).in_node(0).name for idx in node_concat.in_nodes()]
|
||||
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
l1 = [node_split.out_port(idx).get_destination().node.name for idx in node_split.out_ports()]
|
||||
l2 = [node_split.out_node(idx).out_node(0).name for idx in node_split.out_nodes()]
|
||||
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
def test_node_in_ports_order_10_inputs_control_flow(self):
|
||||
edges = [('op_concat', 'op_concat_data', {'out': 'control_flow_0', 'control_flow_edge': True}),
|
||||
('op_concat_data', 'op_split', {'in': 'control_flow_0', 'control_flow_edge': True}),
|
||||
]
|
||||
|
||||
# Filling edges list
|
||||
for idx in range(11):
|
||||
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx),
|
||||
{'out': 'control_flow_0', 'control_flow_edge': True}))
|
||||
edges.append(('in_{}_data'.format(idx), 'op_concat',
|
||||
{'in': 'control_flow_{}'.format(idx), 'control_flow_edge': True}))
|
||||
edges.append(('op_split', 'op_split_{}_data'.format(idx),
|
||||
{'out': 'control_flow_{}'.format(idx), 'control_flow_edge': True}))
|
||||
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx),
|
||||
{'in': 'control_flow_0', 'control_flow_edge': True}))
|
||||
|
||||
graph = build_graph(self.nodes_10_in_10_out, edges)
|
||||
|
||||
node_concat = Node(graph, 'op_concat')
|
||||
node_split = Node(graph, 'op_split')
|
||||
|
||||
self.assertEqual(len(node_concat.in_ports()), len(node_concat.in_nodes()))
|
||||
|
||||
l1 = [node_concat.in_port(idx, control_flow=True).get_source().node.name
|
||||
for idx in node_concat.in_ports(control_flow=True)]
|
||||
l2 = [node_concat.in_node(idx, control_flow=True).in_node(0, control_flow=True).name
|
||||
for idx in node_concat.in_nodes(control_flow=True)]
|
||||
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
l1 = [node_split.out_port(idx, control_flow=True).get_destination().node.name
|
||||
for idx in node_split.out_ports(control_flow=True)]
|
||||
l2 = [node_split.out_node(idx, control_flow=True).out_node(0, control_flow=True).name for idx in
|
||||
node_split.out_nodes(control_flow=True)]
|
||||
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
def test_node_in_ports_order_10_inputs_mixed(self):
|
||||
edges = [('op_concat', 'op_concat_data', {'out': 'control_flow_0', 'control_flow_edge': True}),
|
||||
('op_concat_data', 'op_split', {'in': 'control_flow_0', 'control_flow_edge': True}),
|
||||
]
|
||||
graph = build_graph(self.nodes_10_in_10_out, edges)
|
||||
|
||||
# Filling edges list
|
||||
for idx in range(5):
|
||||
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx)))
|
||||
edges.append(('in_{}_data'.format(idx), 'op_concat'))
|
||||
edges.append(('op_split', 'op_split_{}_data'.format(idx)))
|
||||
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx)))
|
||||
for idx in range(5, 11):
|
||||
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx),
|
||||
{'out': 'control_flow_0', 'control_flow_edge': True}))
|
||||
edges.append(('in_{}_data'.format(idx), 'op_concat',
|
||||
{'in': 'control_flow_{}', 'control_flow_edge': True}))
|
||||
edges.append(('op_split', 'op_split_{}_data'.format(idx),
|
||||
{'out': 'control_flow_{}', 'control_flow_edge': True}))
|
||||
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx),
|
||||
{'in': 'control_flow_0', 'control_flow_edge': True}))
|
||||
|
||||
node_concat = Node(graph, 'op_concat')
|
||||
node_split = Node(graph, 'op_split')
|
||||
|
||||
self.assertEqual(len(node_concat.in_ports()), len(node_concat.in_nodes()))
|
||||
|
||||
l1 = [node_concat.in_port(idx, control_flow=True).get_source().node.name
|
||||
for idx in node_concat.in_ports(control_flow=True)]
|
||||
l2 = [node_concat.in_node(idx, control_flow=True).in_node(0, control_flow=True).name
|
||||
for idx in node_concat.in_nodes(control_flow=True)]
|
||||
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
l1 = [node_split.out_port(idx, control_flow=True).get_destination().node.name
|
||||
for idx in node_split.out_ports(control_flow=True)]
|
||||
l2 = [node_split.out_node(idx, control_flow=True).out_node(0, control_flow=True).name for idx in
|
||||
node_split.out_nodes(control_flow=True)]
|
||||
|
||||
self.assertEqual(l1, l2)
|
||||
|
||||
|
||||
class TestNewGraphAPIFront(unittest.TestCase):
|
||||
nodes = {
|
||||
|
||||
@@ -190,14 +190,16 @@ def build_graph(nodes_attrs: dict, edges: list, update_attributes: dict = None,
|
||||
|
||||
for node in graph.get_op_nodes():
|
||||
# Add in_ports attribute
|
||||
in_edges = node.in_edges()
|
||||
in_edges = node.in_edges(control_flow=True)
|
||||
for attr in in_edges.values():
|
||||
node.add_input_port(idx=attr['in'])
|
||||
control_flow = True if 'control_flow_edge' in attr and attr['control_flow_edge'] is True else False
|
||||
node.add_input_port(idx=attr['in'], control_flow=control_flow)
|
||||
|
||||
# Add out_ports attribute
|
||||
out_edges = node.out_edges()
|
||||
out_edges = node.out_edges(control_flow=True)
|
||||
for attr in out_edges.values():
|
||||
node.add_output_port(idx=attr['out'])
|
||||
control_flow = True if 'control_flow_edge' in attr and attr['control_flow_edge'] is True else False
|
||||
node.add_output_port(idx=attr['out'], control_flow=control_flow)
|
||||
|
||||
graph.graph['cmd_params'] = cli
|
||||
return graph
|
||||
|
||||
Reference in New Issue
Block a user