Fixed order of transformation to convert the TF OD API SSD models (#1887)
* Fixed order of transformation to convert the TF OD API SSD models * Refactored the sub-graph modification for the TF OD API models related to Squeeze/Reshape after SSD heads
This commit is contained in:
parent
c5b19aa8f9
commit
aca452def8
@ -25,6 +25,7 @@ from extensions.front.split_normalizer import SqueezeAxis
|
||||
from extensions.front.standalone_const_eraser import StandaloneConstEraser
|
||||
from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
|
||||
from extensions.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
|
||||
from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
|
||||
from extensions.front.tf.pad_tf_to_pad import PadTFToPad
|
||||
from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
|
||||
mark_input_as_in_correct_layout, mark_output_as_in_correct_layout
|
||||
@ -43,7 +44,7 @@ from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.extractor import output_user_data_repack, add_output_ops
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
|
||||
squeeze_reshape_and_concat, add_fake_background_loc, create_op_node_with_second_input
|
||||
mark_squeeze_reshape_concat_before_detection_output, add_fake_background_loc, create_op_node_with_second_input
|
||||
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph, FrontReplacementFromConfigFileGeneral
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.concat import Concat
|
||||
@ -967,7 +968,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, FakeQuantWithMinMaxVarsToQuantize]
|
||||
|
||||
def run_before(self):
|
||||
return [StandaloneConstEraser, TransposeOrderNormalizer]
|
||||
return [StandaloneConstEraser, TransposeOrderNormalizer, TFSliceToSliceReplacer]
|
||||
|
||||
def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
|
||||
# the DetectionOutput in IE produces single tensor, but in TF it produces two tensors, so create only one output
|
||||
@ -1073,7 +1074,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
node.old_infer(node)
|
||||
|
||||
conv_nodes = backward_bfs_for_operation(node.in_node(0), ['Conv2D'])
|
||||
squeeze_reshape_and_concat(conv_nodes)
|
||||
mark_squeeze_reshape_concat_before_detection_output(conv_nodes)
|
||||
|
||||
|
||||
class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):
|
||||
|
@ -64,14 +64,11 @@ def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dic
|
||||
return node
|
||||
|
||||
|
||||
def squeeze_reshape_and_concat(start_nodes: list):
|
||||
def mark_squeeze_reshape_concat_before_detection_output(start_nodes: list):
|
||||
"""
|
||||
The function looks for Reshape ops after the 'start_nodes' with 4D output and remove the dimension with index 2
|
||||
which should be equal to 1. This is a workaround to make tensor 3D so it's shape will not be transposed during the
|
||||
IR generation. The problem arises when bounding boxes predictions are reshaped from [1, 1, 1, X] to
|
||||
[1, X / 4, 1, 4]. The result tensor should not be transposed because after transpose it will have shape
|
||||
[1, 4, X / 4, 1] and the concatenation over dimension with index 2 will produce incorrect tensor.
|
||||
Also the function looks for Concat ops and change the concat dimension from 2 to 1.
|
||||
The function looks for Reshape, Concat and Squeeze ops after the 'start_nodes' with 4D output and marks them with
|
||||
proper attributes to infer them in original NHWC layout. This is a case of the TensorFlow Object Detection API
|
||||
models for the SSD heads output which produces 4D tensor with bounding box deltas.
|
||||
:param start_nodes: list of nodes to start search from.
|
||||
:return: None
|
||||
"""
|
||||
@ -80,34 +77,28 @@ def squeeze_reshape_and_concat(start_nodes: list):
|
||||
while len(q) != 0:
|
||||
cur_node = q.popleft()
|
||||
if cur_node.has_valid('type'):
|
||||
if cur_node.type == 'DetectionOutput': # do not go beyond the DetectionOutput node
|
||||
if cur_node.soft_get('type') == 'DetectionOutput': # do not go beyond the DetectionOutput node
|
||||
continue
|
||||
if cur_node.op == 'Reshape' and len(cur_node.out_node().shape) == 4:
|
||||
log.debug("Found reshape op with 4D output {}".format(cur_node.id))
|
||||
if cur_node.in_node(1).has_valid('value') and cur_node.in_node(1).value is not None:
|
||||
new_shape = cur_node.in_node(1).value
|
||||
assert new_shape[2] == 1
|
||||
new_shape = np.delete(new_shape, 2)
|
||||
cur_node.in_node(1).value = new_shape
|
||||
cur_node.in_node(1).shape = np.array(new_shape.shape, dtype=np.int64)
|
||||
# run infer function once again
|
||||
cur_node.infer(cur_node)
|
||||
else:
|
||||
log.warning("The reshape size is not defined!")
|
||||
if cur_node.type == 'Concat' and len(cur_node.out_node().shape) == 4:
|
||||
log.debug("Found Concat op with 4D output {}".format(cur_node.id))
|
||||
cur_node.axis = 1
|
||||
# run infer function once again
|
||||
cur_node.infer(cur_node)
|
||||
if cur_node.out_port(0).get_destination().node.op == 'Squeeze':
|
||||
# remove Squeeze node after the Concat
|
||||
squeeze_consumer = cur_node.out_port(0).get_destination().node.out_port(0).get_destination()
|
||||
cur_node.out_port(0).get_connection().set_destination(squeeze_consumer)
|
||||
# the input to Reshape comes from Convolution so it will be converted from NCHW to NHWC layout in the
|
||||
# InsertLayoutPropagationTransposes transformation. But the output should be kept in the original layout
|
||||
if cur_node.soft_get('type') == 'Reshape' and len(cur_node.out_port(0).data.get_shape()) == 4:
|
||||
mark_output_as_in_correct_layout(cur_node, 0)
|
||||
|
||||
out_node_size = len(cur_node.out_nodes())
|
||||
for ind in range(out_node_size):
|
||||
node = cur_node.out_node(ind)
|
||||
q.append(node)
|
||||
# Concat should be inferred in the original layout so the input with concatenation axis should not be
|
||||
# updated from NHWC to NCHW layout
|
||||
if cur_node.soft_get('type') == 'Concat' and len(cur_node.out_port(0).data.get_shape()) == 4:
|
||||
cur_node.in_port(1).__setattr__('input_permutation', None)
|
||||
cur_node['nchw_layout'] = True
|
||||
cur_node.out_node(0)['nchw_layout'] = True
|
||||
|
||||
# Squeeze should be inferred in the original layout so the input with squeeze axis should not be updated
|
||||
# from NHWC to NCHW layout. The input is marked as in correct layout to prevent from inserting Transpose
|
||||
# from NHWC to NCHW.
|
||||
if cur_node.soft_get('type') == 'Squeeze' and len(cur_node.in_port(0).data.get_shape()) == 4:
|
||||
cur_node.in_port(1).__setattr__('input_permutation', None)
|
||||
mark_input_as_in_correct_layout(cur_node, 0)
|
||||
|
||||
[q.append(port.node) for port in cur_node.out_port(0).get_destinations()]
|
||||
|
||||
|
||||
def add_convolution_to_swap_xy_coordinates(graph: Graph, input_node: Node, coordinates_size: int):
|
||||
|
Loading…
Reference in New Issue
Block a user