Updated conversion of TF OD API 2.4 SSD models (#6473)
* Updated conversion of TF OD API 2.4 SSD models * Fixed issue when more Conv2D nodes were selected for weights permutation when converting TF OD API models * Code style fixes * Fixed code comments
This commit is contained in:
@@ -955,7 +955,8 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
detection_output_node.name = 'detection_output'
|
||||
|
||||
if coordinates_swap_method == 'swap_weights':
|
||||
swap_weights_xy(graph, backward_bfs_for_operation(detection_output_node.in_node(0), ['MatMul', 'Conv2D']))
|
||||
swap_weights_xy(graph, backward_bfs_for_operation(detection_output_node.in_node(0), ['MatMul', 'Conv2D'],
|
||||
['ShapeOf']))
|
||||
|
||||
# when the use_matmul_crop_and_resize = True then the prior boxes were not swapped and we need to swap them from
|
||||
# YXYX to XYXY before passing to the DetectionOutput operation
|
||||
@@ -1312,7 +1313,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
|
||||
# compared to the IE's DetectionOutput, the TF keeps the locations in YXYX, need to get back to the XYXY
|
||||
# for last convolutions that operate the locations need to swap the X and Y for output feature weights & biases
|
||||
conv_nodes = backward_bfs_for_operation(detection_output_node.in_node(0), ['Conv2D'])
|
||||
conv_nodes = backward_bfs_for_operation(detection_output_node.in_node(0), ['Conv2D'], ['ShapeOf'])
|
||||
swap_weights_xy(graph, conv_nodes)
|
||||
|
||||
# As outputs are replaced with a postprocessing node, outgoing tensor names are no longer
|
||||
@@ -1355,7 +1356,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
|
||||
node.old_infer(node)
|
||||
|
||||
conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'])
|
||||
conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'], ['ShapeOf'])
|
||||
mark_squeeze_reshape_concat_before_detection_output(conv_nodes)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor"],
|
||||
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor",
|
||||
"StatefulPartitionedCall/map/Shape"],
|
||||
"end_nodes": ["StatefulPartitionedCall/map/TensorArrayV2Stack/TensorListStack",
|
||||
"StatefulPartitionedCall/map/TensorArrayV2Stack_1/TensorListStack"]
|
||||
},
|
||||
@@ -32,10 +33,11 @@
|
||||
"StatefulPartitionedCall/Identity_7"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Postprocessor/Reshape_1",
|
||||
"StatefulPartitionedCall/Postprocessor/raw_box_encodings",
|
||||
"StatefulPartitionedCall/Postprocessor/scale_logits",
|
||||
"StatefulPartitionedCall/Postprocessor/Tile",
|
||||
"StatefulPartitionedCall/Postprocessor/Cast_1"
|
||||
"StatefulPartitionedCall/Postprocessor/Cast_1",
|
||||
"StatefulPartitionedCall/Postprocessor/Cast"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
|
||||
@@ -12,7 +12,7 @@ from mo.utils.error import Error
|
||||
from mo.utils.utils import refer_to_faq_msg
|
||||
|
||||
|
||||
def backward_bfs_for_operation(start_node: Node, op_names: list):
|
||||
def backward_bfs_for_operation(start_node: Node, op_names: list, skip_op_list: list = None):
|
||||
"""
|
||||
Find node with 'op' attribute equal to one of from 'op_name', searching in the backward direction.
|
||||
In case of branching algorithm goes into each branch, but if it can't find layer in one of them it returns
|
||||
@@ -20,7 +20,10 @@ def backward_bfs_for_operation(start_node: Node, op_names: list):
|
||||
|
||||
:param start_node: Start node for BFS algorithm
|
||||
:param op_names: The list with names of operations to search
|
||||
:param skip_op_list: list of operations to be stopped at if they are met
|
||||
"""
|
||||
if skip_op_list is None:
|
||||
skip_op_list = []
|
||||
ret = []
|
||||
q = deque([start_node])
|
||||
while len(q) != 0:
|
||||
@@ -33,7 +36,8 @@ def backward_bfs_for_operation(start_node: Node, op_names: list):
|
||||
if pnode.id not in ret:
|
||||
ret.append(pnode.id)
|
||||
else:
|
||||
q.append(pnode)
|
||||
if pnode.op not in skip_op_list:
|
||||
q.append(pnode)
|
||||
elif pnode.kind == 'data' and pnode.value is None:
|
||||
q.append(pnode)
|
||||
return [Node(start_node.graph, x) for x in ret]
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.graph import bfs_search, is_connected_component, sub_graph_between_nodes
|
||||
from mo.utils.graph import bfs_search, is_connected_component, sub_graph_between_nodes, backward_bfs_for_operation
|
||||
from unit_tests.utils.graph import regular_op, result, build_graph_with_edge_attrs
|
||||
|
||||
|
||||
class TestGraphUtils(unittest.TestCase):
|
||||
@@ -254,3 +255,109 @@ class TestGraphUtils(unittest.TestCase):
|
||||
sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
|
||||
self.assertIsNotNone(sub_graph_nodes)
|
||||
self.assertListEqual(sorted(sub_graph_nodes), sorted([1, 2, 3, 4]))
|
||||
|
||||
def test_backward_bfs_for_op_no_ops_detected(self):
|
||||
nodes = {**regular_op('input', {'op': 'Parameter'}),
|
||||
**regular_op('hsigmoid', {'op': 'HSigmoid'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [('input', 'hsigmoid', {'out': 0, 'in': 0}),
|
||||
('hsigmoid', 'result', {'out': 0, 'in': 0}),
|
||||
]
|
||||
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['NonExistingOp'])
|
||||
self.assertEqual(len(found_nodes), 0)
|
||||
|
||||
def test_backward_bfs_for_op_closest_op_detected(self):
|
||||
"""
|
||||
input -> hsigmoid_1 -> hsigmoid_2 -> result
|
||||
The returned op should be first met HSigmoid which is hsigmoid_2
|
||||
"""
|
||||
nodes = {**regular_op('input', {'op': 'Parameter'}),
|
||||
**regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
|
||||
**regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [('input', 'hsigmoid_1', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_1', 'hsigmoid_2', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_2', 'result', {'out': 0, 'in': 0}),
|
||||
]
|
||||
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['HSigmoid'])
|
||||
self.assertEqual(len(found_nodes), 1)
|
||||
self.assertEqual(found_nodes[0].id, 'hsigmoid_2')
|
||||
|
||||
def test_backward_bfs_for_op_parallel_branch_op_detected(self):
|
||||
r"""
|
||||
input_1 -> hsigmoid_1 -> hsigmoid_2 ->
|
||||
\
|
||||
- Concat->result
|
||||
/
|
||||
input_2 -> hsigmoid_3 -> hsigmoid_4 ->
|
||||
The returned op should be first met HSigmoids which are hsigmoid_2 and hsigmoid_4
|
||||
"""
|
||||
nodes = {**regular_op('input_1', {'op': 'Parameter'}),
|
||||
**regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
|
||||
**regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
|
||||
**regular_op('input_2', {'op': 'Parameter'}),
|
||||
**regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
|
||||
**regular_op('hsigmoid_4', {'op': 'HSigmoid'}),
|
||||
**regular_op('concat', {'op': 'Concat'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [('input_1', 'hsigmoid_1', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_1', 'hsigmoid_2', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_2', 'concat', {'out': 0, 'in': 0}),
|
||||
('input_2', 'hsigmoid_3', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_3', 'hsigmoid_4', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_4', 'concat', {'out': 0, 'in': 1}),
|
||||
('concat', 'result', {'out': 0, 'in': 0}),
|
||||
]
|
||||
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['HSigmoid'])
|
||||
self.assertEqual(len(found_nodes), 2)
|
||||
self.assertSetEqual({found_nodes[0].id, found_nodes[1].id}, {'hsigmoid_2', 'hsigmoid_4'})
|
||||
|
||||
def test_backward_bfs_for_op_parallel_branch_stop_op(self):
|
||||
r"""
|
||||
input_1 -> hsigmoid_1 -> hsigmoid_2 ->
|
||||
\
|
||||
- Concat->result
|
||||
/
|
||||
input_2 -> hsigmoid_3 -> ShapeOf ->
|
||||
The returned op should be first met HSigmoids which is hsigmoid_2, but not the hsigmoid_3 located after banned
|
||||
operation of type "ShapeOf"
|
||||
"""
|
||||
nodes = {**regular_op('input_1', {'op': 'Parameter'}),
|
||||
**regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
|
||||
**regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
|
||||
**regular_op('input_2', {'op': 'Parameter'}),
|
||||
**regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
|
||||
**regular_op('shapeof', {'op': 'ShapeOf'}),
|
||||
**regular_op('concat', {'op': 'Concat'}),
|
||||
**result('result'),
|
||||
}
|
||||
edges = [('input_1', 'hsigmoid_1', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_1', 'hsigmoid_2', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_2', 'concat', {'out': 0, 'in': 0}),
|
||||
('input_2', 'hsigmoid_3', {'out': 0, 'in': 0}),
|
||||
('hsigmoid_3', 'shapeof', {'out': 0, 'in': 0}),
|
||||
('shapeof', 'concat', {'out': 0, 'in': 1}),
|
||||
('concat', 'result', {'out': 0, 'in': 0}),
|
||||
]
|
||||
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
graph.stage = 'front'
|
||||
|
||||
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['HSigmoid'], ['ShapeOf'])
|
||||
self.assertEqual(len(found_nodes), 1)
|
||||
self.assertEqual(found_nodes[0].id, 'hsigmoid_2')
|
||||
|
||||
Reference in New Issue
Block a user