Updated conversion of TF OD API 2.4 SSD models (#6473)

* Updated conversion of TF OD API 2.4 SSD models

* Fixed issue when more Conv2D nodes were selected for weights permutation when converting TF OD API models

* Code style fixes

* Fixed code comments
This commit is contained in:
Evgeny Lazarev
2021-07-02 17:35:59 +03:00
committed by GitHub
parent 8cc1737b5d
commit 21d060ac2b
4 changed files with 124 additions and 10 deletions

View File

@@ -955,7 +955,8 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
detection_output_node.name = 'detection_output'
if coordinates_swap_method == 'swap_weights':
swap_weights_xy(graph, backward_bfs_for_operation(detection_output_node.in_node(0), ['MatMul', 'Conv2D']))
swap_weights_xy(graph, backward_bfs_for_operation(detection_output_node.in_node(0), ['MatMul', 'Conv2D'],
['ShapeOf']))
# when the use_matmul_crop_and_resize = True then the prior boxes were not swapped and we need to swap them from
# YXYX to XYXY before passing to the DetectionOutput operation
@@ -1312,7 +1313,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
# compared to the IE's DetectionOutput, the TF keeps the locations in YXYX, need to get back to the XYXY
# for last convolutions that operate the locations need to swap the X and Y for output feature weights & biases
conv_nodes = backward_bfs_for_operation(detection_output_node.in_node(0), ['Conv2D'])
conv_nodes = backward_bfs_for_operation(detection_output_node.in_node(0), ['Conv2D'], ['ShapeOf'])
swap_weights_xy(graph, conv_nodes)
# As outputs are replaced with a postprocessing node, outgoing tensor names are no longer
@@ -1355,7 +1356,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
node.old_infer(node)
conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'])
conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'], ['ShapeOf'])
mark_squeeze_reshape_concat_before_detection_output(conv_nodes)

View File

@@ -1,7 +1,8 @@
[
{
"custom_attributes": {
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor"],
"start_nodes": ["StatefulPartitionedCall/map/TensorArrayUnstack/TensorListFromTensor",
"StatefulPartitionedCall/map/Shape"],
"end_nodes": ["StatefulPartitionedCall/map/TensorArrayV2Stack/TensorListStack",
"StatefulPartitionedCall/map/TensorArrayV2Stack_1/TensorListStack"]
},
@@ -32,10 +33,11 @@
"StatefulPartitionedCall/Identity_7"
],
"start_points": [
"StatefulPartitionedCall/Postprocessor/Reshape_1",
"StatefulPartitionedCall/Postprocessor/raw_box_encodings",
"StatefulPartitionedCall/Postprocessor/scale_logits",
"StatefulPartitionedCall/Postprocessor/Tile",
"StatefulPartitionedCall/Postprocessor/Cast_1"
"StatefulPartitionedCall/Postprocessor/Cast_1",
"StatefulPartitionedCall/Postprocessor/Cast"
]
},
"match_kind": "points"

View File

@@ -12,7 +12,7 @@ from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
def backward_bfs_for_operation(start_node: Node, op_names: list):
def backward_bfs_for_operation(start_node: Node, op_names: list, skip_op_list: list = None):
"""
Find node with 'op' attribute equal to one of from 'op_name', searching in the backward direction.
In case of branching algorithm goes into each branch, but if it can't find layer in one of them it returns
@@ -20,7 +20,10 @@ def backward_bfs_for_operation(start_node: Node, op_names: list):
:param start_node: Start node for BFS algorithm
:param op_names: The list with names of operations to search
:param skip_op_list: list of operations to be stopped at if they are met
"""
if skip_op_list is None:
skip_op_list = []
ret = []
q = deque([start_node])
while len(q) != 0:
@@ -33,7 +36,8 @@ def backward_bfs_for_operation(start_node: Node, op_names: list):
if pnode.id not in ret:
ret.append(pnode.id)
else:
q.append(pnode)
if pnode.op not in skip_op_list:
q.append(pnode)
elif pnode.kind == 'data' and pnode.value is None:
q.append(pnode)
return [Node(start_node.graph, x) for x in ret]

View File

@@ -3,9 +3,10 @@
import unittest
from mo.graph.graph import Graph
from mo.graph.graph import Graph, Node
from mo.utils.error import Error
from mo.utils.graph import bfs_search, is_connected_component, sub_graph_between_nodes
from mo.utils.graph import bfs_search, is_connected_component, sub_graph_between_nodes, backward_bfs_for_operation
from unit_tests.utils.graph import regular_op, result, build_graph_with_edge_attrs
class TestGraphUtils(unittest.TestCase):
@@ -254,3 +255,109 @@ class TestGraphUtils(unittest.TestCase):
sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
self.assertIsNotNone(sub_graph_nodes)
self.assertListEqual(sorted(sub_graph_nodes), sorted([1, 2, 3, 4]))
def test_backward_bfs_for_op_no_ops_detected(self):
nodes = {**regular_op('input', {'op': 'Parameter'}),
**regular_op('hsigmoid', {'op': 'HSigmoid'}),
**result('result'),
}
edges = [('input', 'hsigmoid', {'out': 0, 'in': 0}),
('hsigmoid', 'result', {'out': 0, 'in': 0}),
]
graph = build_graph_with_edge_attrs(nodes, edges)
graph.stage = 'front'
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['NonExistingOp'])
self.assertEqual(len(found_nodes), 0)
def test_backward_bfs_for_op_closest_op_detected(self):
"""
input -> hsigmoid_1 -> hsigmoid_2 -> result
The returned op should be first met HSigmoid which is hsigmoid_2
"""
nodes = {**regular_op('input', {'op': 'Parameter'}),
**regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
**regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
**result('result'),
}
edges = [('input', 'hsigmoid_1', {'out': 0, 'in': 0}),
('hsigmoid_1', 'hsigmoid_2', {'out': 0, 'in': 0}),
('hsigmoid_2', 'result', {'out': 0, 'in': 0}),
]
graph = build_graph_with_edge_attrs(nodes, edges)
graph.stage = 'front'
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['HSigmoid'])
self.assertEqual(len(found_nodes), 1)
self.assertEqual(found_nodes[0].id, 'hsigmoid_2')
def test_backward_bfs_for_op_parallel_branch_op_detected(self):
r"""
input_1 -> hsigmoid_1 -> hsigmoid_2 ->
\
- Concat->result
/
input_2 -> hsigmoid_3 -> hsigmoid_4 ->
The returned op should be first met HSigmoids which are hsigmoid_2 and hsigmoid_4
"""
nodes = {**regular_op('input_1', {'op': 'Parameter'}),
**regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
**regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
**regular_op('input_2', {'op': 'Parameter'}),
**regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
**regular_op('hsigmoid_4', {'op': 'HSigmoid'}),
**regular_op('concat', {'op': 'Concat'}),
**result('result'),
}
edges = [('input_1', 'hsigmoid_1', {'out': 0, 'in': 0}),
('hsigmoid_1', 'hsigmoid_2', {'out': 0, 'in': 0}),
('hsigmoid_2', 'concat', {'out': 0, 'in': 0}),
('input_2', 'hsigmoid_3', {'out': 0, 'in': 0}),
('hsigmoid_3', 'hsigmoid_4', {'out': 0, 'in': 0}),
('hsigmoid_4', 'concat', {'out': 0, 'in': 1}),
('concat', 'result', {'out': 0, 'in': 0}),
]
graph = build_graph_with_edge_attrs(nodes, edges)
graph.stage = 'front'
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['HSigmoid'])
self.assertEqual(len(found_nodes), 2)
self.assertSetEqual({found_nodes[0].id, found_nodes[1].id}, {'hsigmoid_2', 'hsigmoid_4'})
def test_backward_bfs_for_op_parallel_branch_stop_op(self):
r"""
input_1 -> hsigmoid_1 -> hsigmoid_2 ->
\
- Concat->result
/
input_2 -> hsigmoid_3 -> ShapeOf ->
The returned op should be first met HSigmoids which is hsigmoid_2, but not the hsigmoid_3 located after banned
operation of type "ShapeOf"
"""
nodes = {**regular_op('input_1', {'op': 'Parameter'}),
**regular_op('hsigmoid_1', {'op': 'HSigmoid'}),
**regular_op('hsigmoid_2', {'op': 'HSigmoid'}),
**regular_op('input_2', {'op': 'Parameter'}),
**regular_op('hsigmoid_3', {'op': 'HSigmoid'}),
**regular_op('shapeof', {'op': 'ShapeOf'}),
**regular_op('concat', {'op': 'Concat'}),
**result('result'),
}
edges = [('input_1', 'hsigmoid_1', {'out': 0, 'in': 0}),
('hsigmoid_1', 'hsigmoid_2', {'out': 0, 'in': 0}),
('hsigmoid_2', 'concat', {'out': 0, 'in': 0}),
('input_2', 'hsigmoid_3', {'out': 0, 'in': 0}),
('hsigmoid_3', 'shapeof', {'out': 0, 'in': 0}),
('shapeof', 'concat', {'out': 0, 'in': 1}),
('concat', 'result', {'out': 0, 'in': 0}),
]
graph = build_graph_with_edge_attrs(nodes, edges)
graph.stage = 'front'
found_nodes = backward_bfs_for_operation(Node(graph, 'result'), ['HSigmoid'], ['ShapeOf'])
self.assertEqual(len(found_nodes), 1)
self.assertEqual(found_nodes[0].id, 'hsigmoid_2')