Change sort functions to fix incorrect order in port dictionaries (#7283)

* Change sort functions to fix incorrect order

* Add separate sorts for different cases

* Update sort function to correct control flow edges handling

* Refactor all sorts for dictionary ordering

* Update build_graph function for correct control_flow edges creation

* Add tests for nodes and ports sort equality

* Fix wrong port value

* Add direct conversion from bool to int

* Use .replace instead of .strip

* Add test with both control_flow and ordinary edges, refactored edges lists filling

* Refactored nodes dict filling

* Delete unused test code
This commit is contained in:
Anton Chetverikov
2021-09-20 10:17:44 +03:00
committed by GitHub
parent 168d8b9e84
commit 9cfdad9afc
4 changed files with 143 additions and 14 deletions

View File

@@ -874,7 +874,7 @@ class ObjectDetectionAPIPreprocessor2Replacement(FrontReplacementFromConfigFileG
# replace sub-graph between start and end nodes (including them) with new_preprocessing_ops nodes
end_node.out_port(0).get_connection().set_source(new_preprocessing_ops[-1].out_port(0))
start_node.in_port(0).get_connection().set_destination(
new_preprocessing_ops[0].in_port(new_preprocessing_ops[0].is_in_port_connected(0)))
new_preprocessing_ops[0].in_port(int(new_preprocessing_ops[0].is_in_port_connected(0))))
else:
if trailing: # case 2
# change output of the end_node to be produced with the start node producer

View File

@@ -147,7 +147,7 @@ class Node:
for idx in self._in_ports:
if control_flow or 'control_flow' not in self._in_ports[idx] or not self._in_ports[idx]['control_flow']:
ports.update({idx: self.in_port(idx, control_flow=control_flow)})
return dict_to_ordered_dict(ports, func=lambda t: str(t))
return dict_to_ordered_dict(ports, func=lambda t: int(str(t).replace('control_flow_', '')))
def out_port(self, idx=None, control_flow=False) -> Port:
if not self.has_valid('_out_ports'):
@@ -165,7 +165,7 @@ class Node:
for idx in self._out_ports:
if control_flow or 'control_flow' not in self._out_ports[idx] or not self._out_ports[idx]['control_flow']:
ports.update({idx: self.out_port(idx, control_flow=control_flow)})
return dict_to_ordered_dict(ports, func=lambda t: str(t))
return dict_to_ordered_dict(ports, func=lambda t: int(str(t).replace('control_flow_', '')))
def has_port(self, port_type, idx, control_flow=False):
assert port_type in ['in', 'out'], "Invalid usage of has_port method"
@@ -195,12 +195,14 @@ class Node:
def in_nodes_edges(self, control_flow: bool = False):
return dict_to_ordered_dict({x[1]['in']: (Node(self.graph, x[0]), x[1]) for x in
self.get_inputs(control_flow=control_flow)})
self.get_inputs(control_flow=control_flow)},
func=lambda t: int(str(t).replace('control_flow_', '')))
def in_nodes(self, control_flow: bool = False):
if self.kind == 'op':
return dict_to_ordered_dict({x[1]['in']: Node(self.graph, x[0]) for x in
self.get_inputs(control_flow=control_flow)})
self.get_inputs(control_flow=control_flow)},
func=lambda t: int(str(t).replace('control_flow_', '')))
elif self.kind == 'data':
return [Node(self.graph, n) for n, d in self.get_inputs(control_flow=control_flow)]
@@ -211,20 +213,23 @@ class Node:
assert self.has('kind')
assert self.kind in ['op', 'data']
if self.kind == 'op':
return dict_to_ordered_dict({x[1]['in']: x[1] for x in self.get_inputs(control_flow=control_flow)})
return dict_to_ordered_dict({x[1]['in']: x[1] for x in self.get_inputs(control_flow=control_flow)},
func=lambda t: int(str(t).replace('control_flow_', '')))
elif self.kind == 'data':
return [d for n, d in self.get_inputs(control_flow=control_flow)]
def out_nodes_edges(self, control_flow: bool = False):
return dict_to_ordered_dict({x[1]['out']: (Node(self.graph, x[0]), x[1]) for x in
self.get_outputs(control_flow=control_flow)})
self.get_outputs(control_flow=control_flow)},
func=lambda t: int(str(t).replace('control_flow_', '')))
def out_nodes(self, control_flow: bool = False):
assert self.has('kind')
assert self.kind in ['op', 'data']
if self.kind == 'op':
return dict_to_ordered_dict({x[1]['out']: Node(self.graph, x[0]) for x in
self.get_outputs(control_flow=control_flow)})
self.get_outputs(control_flow=control_flow)},
func=lambda t: int(str(t).replace('control_flow_', '')))
elif self.kind == 'data':
return [Node(self.graph, n) for n, d in self.get_outputs(control_flow=control_flow)]
@@ -232,7 +237,8 @@ class Node:
assert self.has('kind')
assert self.kind in ['op', 'data']
if self.kind == 'op':
return dict_to_ordered_dict({x[1]['out']: x[1] for x in self.get_outputs(control_flow=control_flow)})
return dict_to_ordered_dict({x[1]['out']: x[1] for x in self.get_outputs(control_flow=control_flow)},
func=lambda t: int(str(t).replace('control_flow_', '')))
elif self.kind == 'data':
return [d for n, d in self.get_outputs(control_flow=control_flow)]

View File

@@ -11,7 +11,7 @@ from mo.graph.graph import Node, Graph, add_opoutput, dict_includes_compare_attr
from mo.ops.const import Const
from mo.utils.error import Error
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph
from unit_tests.utils.graph import build_graph, build_graph_with_edge_attrs
nodes = {
'0': {'name': 'input1', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Parameter'},
@@ -429,6 +429,20 @@ class TestNewGraphAPIMiddle(unittest.TestCase):
'const_1_data': {'value': None, 'shape': None, 'kind': 'data'},
}
nodes_10_in_10_out = {
'op_concat': {'type': 'Concat', 'value': None, 'kind': 'op', 'op': 'Concat'},
'op_concat_data': {'value': None, 'shape': None, 'kind': 'data'},
'op_split': {'type': 'Split', 'value': None, 'kind': 'op', 'op': 'Split'},
}
# Filling nodes list
for idx in range(11):
nodes_10_in_10_out.update({'in_{}'.format(idx): {'type': 'Parameter', 'value': None, 'kind': 'op', 'op': 'Parameter'}})
nodes_10_in_10_out.update({'in_{}_data'.format(idx): {'value': None, 'shape': None, 'kind': 'data'}})
nodes_10_in_10_out.update({'out_{}'.format(idx): {'type': 'Parameter', 'value': None, 'kind': 'op', 'op': 'Parameter'}})
nodes_10_in_10_out.update({'op_split_{}_data'.format(idx): {'value': None, 'shape': None, 'kind': 'data'}})
###########################################
###### TESTS FOR PORT CLASS METHODS #######
###########################################
@@ -1083,6 +1097,113 @@ class TestNewGraphAPIMiddle(unittest.TestCase):
for idx in range(len(node.out_ports())):
self.assertEqual(node.out_port(idx), node.out_ports()[idx])
def test_node_in_ports_order_10_inputs(self):
edges = [('op_concat', 'op_concat_data'),
('op_concat_data', 'op_split'),
]
# Filling edges list
for idx in range(11):
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx)))
edges.append(('in_{}_data'.format(idx), 'op_concat', {'in': idx}))
edges.append(('op_split', 'op_split_{}_data'.format(idx), {'out': idx}))
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx)))
graph = build_graph(self.nodes_10_in_10_out, edges)
node_concat = Node(graph, 'op_concat')
node_split = Node(graph, 'op_split')
self.assertEqual(len(node_concat.in_ports()), len(node_concat.in_nodes()))
l1 = [node_concat.in_port(idx).get_source().node.name for idx in node_concat.in_ports()]
l2 = [node_concat.in_node(idx).in_node(0).name for idx in node_concat.in_nodes()]
self.assertEqual(l1, l2)
l1 = [node_split.out_port(idx).get_destination().node.name for idx in node_split.out_ports()]
l2 = [node_split.out_node(idx).out_node(0).name for idx in node_split.out_nodes()]
self.assertEqual(l1, l2)
def test_node_in_ports_order_10_inputs_control_flow(self):
edges = [('op_concat', 'op_concat_data', {'out': 'control_flow_0', 'control_flow_edge': True}),
('op_concat_data', 'op_split', {'in': 'control_flow_0', 'control_flow_edge': True}),
]
# Filling edges list
for idx in range(11):
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx),
{'out': 'control_flow_0', 'control_flow_edge': True}))
edges.append(('in_{}_data'.format(idx), 'op_concat',
{'in': 'control_flow_{}'.format(idx), 'control_flow_edge': True}))
edges.append(('op_split', 'op_split_{}_data'.format(idx),
{'out': 'control_flow_{}'.format(idx), 'control_flow_edge': True}))
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx),
{'in': 'control_flow_0', 'control_flow_edge': True}))
graph = build_graph(self.nodes_10_in_10_out, edges)
node_concat = Node(graph, 'op_concat')
node_split = Node(graph, 'op_split')
self.assertEqual(len(node_concat.in_ports()), len(node_concat.in_nodes()))
l1 = [node_concat.in_port(idx, control_flow=True).get_source().node.name
for idx in node_concat.in_ports(control_flow=True)]
l2 = [node_concat.in_node(idx, control_flow=True).in_node(0, control_flow=True).name
for idx in node_concat.in_nodes(control_flow=True)]
self.assertEqual(l1, l2)
l1 = [node_split.out_port(idx, control_flow=True).get_destination().node.name
for idx in node_split.out_ports(control_flow=True)]
l2 = [node_split.out_node(idx, control_flow=True).out_node(0, control_flow=True).name for idx in
node_split.out_nodes(control_flow=True)]
self.assertEqual(l1, l2)
def test_node_in_ports_order_10_inputs_mixed(self):
edges = [('op_concat', 'op_concat_data', {'out': 'control_flow_0', 'control_flow_edge': True}),
('op_concat_data', 'op_split', {'in': 'control_flow_0', 'control_flow_edge': True}),
]
graph = build_graph(self.nodes_10_in_10_out, edges)
# Filling edges list
for idx in range(5):
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx)))
edges.append(('in_{}_data'.format(idx), 'op_concat'))
edges.append(('op_split', 'op_split_{}_data'.format(idx)))
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx)))
for idx in range(5, 11):
edges.append(('in_{}'.format(idx), 'in_{}_data'.format(idx),
{'out': 'control_flow_0', 'control_flow_edge': True}))
edges.append(('in_{}_data'.format(idx), 'op_concat',
{'in': 'control_flow_{}', 'control_flow_edge': True}))
edges.append(('op_split', 'op_split_{}_data'.format(idx),
{'out': 'control_flow_{}', 'control_flow_edge': True}))
edges.append(('op_split_{}_data'.format(idx), 'out_{}'.format(idx),
{'in': 'control_flow_0', 'control_flow_edge': True}))
node_concat = Node(graph, 'op_concat')
node_split = Node(graph, 'op_split')
self.assertEqual(len(node_concat.in_ports()), len(node_concat.in_nodes()))
l1 = [node_concat.in_port(idx, control_flow=True).get_source().node.name
for idx in node_concat.in_ports(control_flow=True)]
l2 = [node_concat.in_node(idx, control_flow=True).in_node(0, control_flow=True).name
for idx in node_concat.in_nodes(control_flow=True)]
self.assertEqual(l1, l2)
l1 = [node_split.out_port(idx, control_flow=True).get_destination().node.name
for idx in node_split.out_ports(control_flow=True)]
l2 = [node_split.out_node(idx, control_flow=True).out_node(0, control_flow=True).name for idx in
node_split.out_nodes(control_flow=True)]
self.assertEqual(l1, l2)
class TestNewGraphAPIFront(unittest.TestCase):
nodes = {

View File

@@ -190,14 +190,16 @@ def build_graph(nodes_attrs: dict, edges: list, update_attributes: dict = None,
for node in graph.get_op_nodes():
# Add in_ports attribute
in_edges = node.in_edges()
in_edges = node.in_edges(control_flow=True)
for attr in in_edges.values():
node.add_input_port(idx=attr['in'])
control_flow = True if 'control_flow_edge' in attr and attr['control_flow_edge'] is True else False
node.add_input_port(idx=attr['in'], control_flow=control_flow)
# Add out_ports attribute
out_edges = node.out_edges()
out_edges = node.out_edges(control_flow=True)
for attr in out_edges.values():
node.add_output_port(idx=attr['out'])
control_flow = True if 'control_flow_edge' in attr and attr['control_flow_edge'] is True else False
node.add_output_port(idx=attr['out'], control_flow=control_flow)
graph.graph['cmd_params'] = cli
return graph