Added support for a new version of the TF OD API pre-processing part (#3063)

* Added support for a new version of the TF OD API pre-processing part of the mode

* Get rid of legacy API usage

* Fix comment and added assert

* Wording
This commit is contained in:
Evgeny Lazarev 2020-11-11 11:53:10 +03:00 committed by GitHub
parent 6b09d5769f
commit f633f0035c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -486,6 +486,23 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
new_nodes_to_remove.remove(node_to_keep)
return new_nodes_to_remove
def is_preprocessing_applied_before_resize(self, to_float: Node, mul: Node, sub: Node):
"""
The function checks if the output of 'to_float' operation is consumed by 'mul' or 'sub'. If this is true then
the pre-processing (mean/scale) is applied before the image resize. The image resize was applied first in the
original version of the TF OD API models, but in the recent versions it is applied after.
:param to_float: the Cast node which converts the input tensor to Float
:param mul: the Mul node (can be None)
:param sub: the Sub node
:return: the result of the check
"""
assert sub is not None, 'The Sub node should not be None. Check the caller function.'
if mul is not None:
return any([port.node.id == mul.id for port in to_float.out_port(0).get_destinations()])
else:
return any([port.node.id == sub.id for port in to_float.out_port(0).get_destinations()])
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
argv = graph.graph['cmd_params']
layout = graph.graph['layout']
@ -494,14 +511,14 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
sub_node = match.output_node(0)[0]
if not sub_node.has('op') or sub_node.op != 'Sub':
if sub_node.soft_get('op') != 'Sub':
raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
'not created with TensorFlow Object Detection API.')
mul_node = None
if sub_node.in_node(0).has('op') and sub_node.in_node(0).op == 'Mul':
if sub_node.in_port(0).get_source().node.soft_get('op') == 'Mul':
log.info('There is image scaling node in the Preprocessor block.')
mul_node = sub_node.in_node(0)
mul_node = sub_node.in_port(0).get_source().node
initial_input_node_name = 'image_tensor'
if initial_input_node_name not in graph.nodes():
@ -521,16 +538,25 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
graph.graph['preprocessed_image_height'] = placeholder_node.shape[get_height_dim(layout, 4)]
graph.graph['preprocessed_image_width'] = placeholder_node.shape[get_width_dim(layout, 4)]
to_float_node = placeholder_node.out_node(0)
if not to_float_node.has('op') or to_float_node.op != 'Cast':
raise Error('The output of the node "{}" is not Cast operation. Cannot apply replacer.'.format(
to_float_node = placeholder_node.out_port(0).get_destination().node
if to_float_node.soft_get('op') != 'Cast':
raise Error('The output of the node "{}" is not Cast operation. Cannot apply transformation.'.format(
initial_input_node_name))
# connect to_float_node directly with node performing scale on mean value subtraction
if mul_node is None:
graph.create_edge(to_float_node, sub_node, 0, 0)
if self.is_preprocessing_applied_before_resize(to_float_node, mul_node, sub_node):
# connect sub node directly to nodes which consume resized image
resize_output_node_id = 'Preprocessor/map/TensorArrayStack/TensorArrayGatherV3'
if resize_output_node_id not in graph.nodes:
raise Error('There is no expected node "{}" in the graph.'.format(resize_output_node_id))
resize_output = Node(graph, resize_output_node_id)
for dst_port in resize_output.out_port(0).get_destinations():
dst_port.get_connection().set_source(sub_node.out_port(0))
else:
graph.create_edge(to_float_node, mul_node, 0, 1)
# connect to_float_node directly with node performing scale on mean value subtraction
if mul_node is None:
to_float_node.out_port(0).connect(sub_node.in_port(0))
else:
to_float_node.out_port(0).connect(mul_node.in_port(1))
print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
' applicable) are kept.')