Fixed order of transformation to convert the TF OD API SSD models (#1887)

* Fixed order of transformation to convert the TF OD API SSD models

* Refactored the sub-graph modification for the TF OD API models related to Squeeze/Reshape after SSD heads
This commit is contained in:
Evgeny Lazarev 2020-08-25 16:39:34 +03:00 committed by GitHub
parent c5b19aa8f9
commit aca452def8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 36 deletions

View File

@ -25,6 +25,7 @@ from extensions.front.split_normalizer import SqueezeAxis
from extensions.front.standalone_const_eraser import StandaloneConstEraser
from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
from extensions.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
from extensions.front.tf.pad_tf_to_pad import PadTFToPad
from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
mark_input_as_in_correct_layout, mark_output_as_in_correct_layout
@ -43,7 +44,7 @@ from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import output_user_data_repack, add_output_ops
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
squeeze_reshape_and_concat, add_fake_background_loc, create_op_node_with_second_input
mark_squeeze_reshape_concat_before_detection_output, add_fake_background_loc, create_op_node_with_second_input
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph, FrontReplacementFromConfigFileGeneral
from mo.graph.graph import Graph, Node
from mo.ops.concat import Concat
@ -967,7 +968,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
return [ObjectDetectionAPIPreprocessorReplacement, FakeQuantWithMinMaxVarsToQuantize]
def run_before(self):
return [StandaloneConstEraser, TransposeOrderNormalizer]
return [StandaloneConstEraser, TransposeOrderNormalizer, TFSliceToSliceReplacer]
def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
# the DetectionOutput in IE produces single tensor, but in TF it produces two tensors, so create only one output
@ -1073,7 +1074,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
node.old_infer(node)
conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'])
squeeze_reshape_and_concat(conv_nodes)
mark_squeeze_reshape_concat_before_detection_output(conv_nodes)
class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):

View File

@ -64,14 +64,11 @@ def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dic
return node
def squeeze_reshape_and_concat(start_nodes: list):
def mark_squeeze_reshape_concat_before_detection_output(start_nodes: list):
"""
The function looks for Reshape ops after the 'start_nodes' with 4D output and removes the dimension with index 2
which should be equal to 1. This is a workaround to make the tensor 3D so its shape will not be transposed during the
IR generation. The problem arises when bounding box predictions are reshaped from [1, 1, 1, X] to
[1, X / 4, 1, 4]. The result tensor should not be transposed because after the transpose it will have shape
[1, 4, X / 4, 1] and the concatenation over the dimension with index 2 will produce an incorrect tensor.
Also the function looks for Concat ops and changes the concat dimension from 2 to 1.
The function looks for Reshape, Concat and Squeeze ops after the 'start_nodes' with 4D output and marks them with
the proper attributes to infer them in the original NHWC layout. This is the case for the TensorFlow Object Detection
API models, where the SSD heads output is a 4D tensor with bounding box deltas.
:param start_nodes: list of nodes to start search from.
:return: None
"""
@ -80,34 +77,28 @@ def squeeze_reshape_and_concat(start_nodes: list):
while len(q) != 0:
cur_node = q.popleft()
if cur_node.has_valid('type'):
if cur_node.type == 'DetectionOutput': # do not go beyond the DetectionOutput node
if cur_node.soft_get('type') == 'DetectionOutput': # do not go beyond the DetectionOutput node
continue
if cur_node.op == 'Reshape' and len(cur_node.out_node().shape) == 4:
log.debug("Found reshape op with 4D output {}".format(cur_node.id))
if cur_node.in_node(1).has_valid('value') and cur_node.in_node(1).value is not None:
new_shape = cur_node.in_node(1).value
assert new_shape[2] == 1
new_shape = np.delete(new_shape, 2)
cur_node.in_node(1).value = new_shape
cur_node.in_node(1).shape = np.array(new_shape.shape, dtype=np.int64)
# run infer function once again
cur_node.infer(cur_node)
else:
log.warning("The reshape size is not defined!")
if cur_node.type == 'Concat' and len(cur_node.out_node().shape) == 4:
log.debug("Found Concat op with 4D output {}".format(cur_node.id))
cur_node.axis = 1
# run infer function once again
cur_node.infer(cur_node)
if cur_node.out_port(0).get_destination().node.op == 'Squeeze':
# remove Squeeze node after the Concat
squeeze_consumer = cur_node.out_port(0).get_destination().node.out_port(0).get_destination()
cur_node.out_port(0).get_connection().set_destination(squeeze_consumer)
# the input to Reshape comes from Convolution so it will be converted from NCHW to NHWC layout in the
# InsertLayoutPropagationTransposes transformation. But the output should be kept in the original layout
if cur_node.soft_get('type') == 'Reshape' and len(cur_node.out_port(0).data.get_shape()) == 4:
mark_output_as_in_correct_layout(cur_node, 0)
out_node_size = len(cur_node.out_nodes())
for ind in range(out_node_size):
node = cur_node.out_node(ind)
q.append(node)
# Concat should be inferred in the original layout so the input with concatenation axis should not be
# updated from NHWC to NCHW layout
if cur_node.soft_get('type') == 'Concat' and len(cur_node.out_port(0).data.get_shape()) == 4:
cur_node.in_port(1).__setattr__('input_permutation', None)
cur_node['nchw_layout'] = True
cur_node.out_node(0)['nchw_layout'] = True
# Squeeze should be inferred in the original layout so the input with the squeeze axis should not be updated
# from NHWC to NCHW layout. The input is marked as being in the correct layout to prevent a Transpose from
# NHWC to NCHW from being inserted.
if cur_node.soft_get('type') == 'Squeeze' and len(cur_node.in_port(0).data.get_shape()) == 4:
cur_node.in_port(1).__setattr__('input_permutation', None)
mark_input_as_in_correct_layout(cur_node, 0)
[q.append(port.node) for port in cur_node.out_port(0).get_destinations()]
def add_convolution_to_swap_xy_coordinates(graph: Graph, input_node: Node, coordinates_size: int):