Fix of ObjectDetectionAPIProposalReplacement(). (#14869)
* Fixed ObjectDetectionAPIProposalReplacement() to get correct CropAndResize node. * Small correction. * Moved topological sort with start node to separate method, added tests. * Simplified code.
This commit is contained in:
parent
931fd11eee
commit
e2943c2430
@ -1521,8 +1521,11 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
|
||||
dict(name="reshape_swap_proposals_2d"), proposal)
|
||||
mark_input_as_in_correct_layout(proposal_reshape_2d, 0)
|
||||
|
||||
crop_and_resize_nodes_ids = [node_id for node_id in bfs_search(graph, [match.single_input_node(0)[0].id]) if
|
||||
graph.nodes[node_id]['op'] == 'CropAndResize']
|
||||
# Find closest CropAndResize in topological order
|
||||
start_node = match.single_input_node(0)[0]
|
||||
crop_and_resize_nodes_ids = [node.id for node in graph.pseudo_topological_sort_with_start_node(start_node) if
|
||||
graph.nodes[node.id]['op'] == 'CropAndResize']
|
||||
|
||||
if len(crop_and_resize_nodes_ids) != 0 and swap_proposals:
|
||||
# feed the CropAndResize node with a correct boxes information produced with the Proposal layer
|
||||
# find the first CropAndResize node in the BFS order. This is needed in the case when we already swapped
|
||||
|
@ -988,6 +988,18 @@ class Graph(nx.MultiDiGraph):
|
||||
else:
|
||||
return list(reversed(order))
|
||||
|
||||
def pseudo_topological_sort_with_start_node(self, start_node: Node, reverse: bool = False):
|
||||
nodes_without_inputs = [start_node.soft_get('name')]
|
||||
visited = set()
|
||||
order = self.dfs(nodes_without_inputs[0], visited)
|
||||
|
||||
order = [Node(self, node) for node in order]
|
||||
|
||||
if reverse:
|
||||
return order
|
||||
else:
|
||||
return list(reversed(order))
|
||||
|
||||
def clean_up(self, undead_node_types: list = None):
|
||||
if undead_node_types is None:
|
||||
undead_node_types = []
|
||||
|
@ -1793,3 +1793,75 @@ class TestGetSetAttributeBetweenNodes(unittest.TestCase):
|
||||
self.assertTrue(get_edge_attribute_between_nodes(a_node, f_node, 'Attr') == "new_value_3")
|
||||
self.assertTrue(get_edge_attribute_between_nodes(b_node, d_node, 'Attr') == "new_value_4")
|
||||
self.assertTrue(get_edge_attribute_between_nodes(b_node, f_node, 'Attr') == "new_value_5")
|
||||
|
||||
|
||||
class TestTopologicalSort(unittest.TestCase):
|
||||
nodes = {
|
||||
'A': {'id': 0, 'kind': 'op'},
|
||||
'B': {'id': 1, 'kind': 'op'},
|
||||
'C': {'id': 2, 'kind': 'op'},
|
||||
'D': {'id': 3, 'kind': 'op'},
|
||||
'E': {'id': 4, 'kind': 'op'},
|
||||
}
|
||||
|
||||
def build_test_graph(self):
|
||||
graph = build_graph(self.nodes, [
|
||||
('A', 'B', {'in': 0, 'out': 0}),
|
||||
('A', 'C', {'in': 0, 'out': 1}),
|
||||
('A', 'D', {'in': 0, 'out': 2}),
|
||||
('A', 'E', {'in': 0, 'out': 3}),
|
||||
('B', 'D', {'in': 1, 'out': 0}),
|
||||
('C', 'D', {'in': 2, 'out': 0}),
|
||||
('C', 'E', {'in': 1, 'out': 1}),
|
||||
('D', 'E', {'in': 2, 'out': 0}),
|
||||
])
|
||||
return graph
|
||||
|
||||
def test_sort_with_start_node(self):
|
||||
graph = self.build_test_graph()
|
||||
|
||||
stat_node = Node(graph, "A")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
|
||||
assert nodes_names == ['A', 'C', 'B', 'D', 'E']
|
||||
|
||||
stat_node = Node(graph, "B")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
|
||||
assert nodes_names == ['B', 'D', 'E']
|
||||
|
||||
stat_node = Node(graph, "C")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
|
||||
assert nodes_names == ['C', 'D', 'E']
|
||||
|
||||
stat_node = Node(graph, "D")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
|
||||
assert nodes_names == ['D', 'E']
|
||||
|
||||
stat_node = Node(graph, "E")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
|
||||
assert nodes_names == ['E']
|
||||
|
||||
# reverse order
|
||||
stat_node = Node(graph, "A")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
|
||||
reverse=True)]
|
||||
assert nodes_names == ['E', 'D', 'B', 'C', 'A']
|
||||
|
||||
stat_node = Node(graph, "B")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
|
||||
reverse=True)]
|
||||
assert nodes_names == ['E', 'D', 'B']
|
||||
|
||||
stat_node = Node(graph, "C")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
|
||||
reverse=True)]
|
||||
assert nodes_names == ['E', 'D', 'C']
|
||||
|
||||
stat_node = Node(graph, "D")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
|
||||
reverse=True)]
|
||||
assert nodes_names == ['E', 'D']
|
||||
|
||||
stat_node = Node(graph, "E")
|
||||
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
|
||||
reverse=True)]
|
||||
assert nodes_names == ['E']
|
Loading…
Reference in New Issue
Block a user