Fix of ObjectDetectionAPIProposalReplacement(). (#14869)

* Fixed ObjectDetectionAPIProposalReplacement() to get correct CropAndResize node.

* Small correction.

* Moved topological sort with start node to separate method, added tests.

* Simplified code.
This commit is contained in:
Anastasiia Pnevskaia 2023-01-23 10:19:40 +01:00 committed by GitHub
parent 931fd11eee
commit e2943c2430
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 2 deletions

View File

@ -1521,8 +1521,11 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
dict(name="reshape_swap_proposals_2d"), proposal)
mark_input_as_in_correct_layout(proposal_reshape_2d, 0)
crop_and_resize_nodes_ids = [node_id for node_id in bfs_search(graph, [match.single_input_node(0)[0].id]) if
graph.nodes[node_id]['op'] == 'CropAndResize']
# Find closest CropAndResize in topological order
start_node = match.single_input_node(0)[0]
crop_and_resize_nodes_ids = [node.id for node in graph.pseudo_topological_sort_with_start_node(start_node) if
graph.nodes[node.id]['op'] == 'CropAndResize']
if len(crop_and_resize_nodes_ids) != 0 and swap_proposals:
# feed the CropAndResize node with a correct boxes information produced with the Proposal layer
# find the first CropAndResize node in the BFS order. This is needed in the case when we already swapped

View File

@ -988,6 +988,18 @@ class Graph(nx.MultiDiGraph):
else:
return list(reversed(order))
def pseudo_topological_sort_with_start_node(self, start_node: Node, reverse: bool = False):
nodes_without_inputs = [start_node.soft_get('name')]
visited = set()
order = self.dfs(nodes_without_inputs[0], visited)
order = [Node(self, node) for node in order]
if reverse:
return order
else:
return list(reversed(order))
def clean_up(self, undead_node_types: list = None):
if undead_node_types is None:
undead_node_types = []

View File

@ -1793,3 +1793,75 @@ class TestGetSetAttributeBetweenNodes(unittest.TestCase):
self.assertTrue(get_edge_attribute_between_nodes(a_node, f_node, 'Attr') == "new_value_3")
self.assertTrue(get_edge_attribute_between_nodes(b_node, d_node, 'Attr') == "new_value_4")
self.assertTrue(get_edge_attribute_between_nodes(b_node, f_node, 'Attr') == "new_value_5")
class TestTopologicalSort(unittest.TestCase):
nodes = {
'A': {'id': 0, 'kind': 'op'},
'B': {'id': 1, 'kind': 'op'},
'C': {'id': 2, 'kind': 'op'},
'D': {'id': 3, 'kind': 'op'},
'E': {'id': 4, 'kind': 'op'},
}
def build_test_graph(self):
graph = build_graph(self.nodes, [
('A', 'B', {'in': 0, 'out': 0}),
('A', 'C', {'in': 0, 'out': 1}),
('A', 'D', {'in': 0, 'out': 2}),
('A', 'E', {'in': 0, 'out': 3}),
('B', 'D', {'in': 1, 'out': 0}),
('C', 'D', {'in': 2, 'out': 0}),
('C', 'E', {'in': 1, 'out': 1}),
('D', 'E', {'in': 2, 'out': 0}),
])
return graph
def test_sort_with_start_node(self):
graph = self.build_test_graph()
stat_node = Node(graph, "A")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
assert nodes_names == ['A', 'C', 'B', 'D', 'E']
stat_node = Node(graph, "B")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
assert nodes_names == ['B', 'D', 'E']
stat_node = Node(graph, "C")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
assert nodes_names == ['C', 'D', 'E']
stat_node = Node(graph, "D")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
assert nodes_names == ['D', 'E']
stat_node = Node(graph, "E")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node)]
assert nodes_names == ['E']
# reverse order
stat_node = Node(graph, "A")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
reverse=True)]
assert nodes_names == ['E', 'D', 'B', 'C', 'A']
stat_node = Node(graph, "B")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
reverse=True)]
assert nodes_names == ['E', 'D', 'B']
stat_node = Node(graph, "C")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
reverse=True)]
assert nodes_names == ['E', 'D', 'C']
stat_node = Node(graph, "D")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
reverse=True)]
assert nodes_names == ['E', 'D']
stat_node = Node(graph, "E")
nodes_names = [node.name for node in graph.pseudo_topological_sort_with_start_node(start_node=stat_node,
reverse=True)]
assert nodes_names == ['E']