Added support for a new version of the TF OD API pre-processing part (#3063)
* Added support for a new version of the TF OD API pre-processing part of the mode * Get rid of legacy API usage * Fix comment and added assert * Wording
This commit is contained in:
parent
6b09d5769f
commit
f633f0035c
@ -486,6 +486,23 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
|
||||
new_nodes_to_remove.remove(node_to_keep)
|
||||
return new_nodes_to_remove
|
||||
|
||||
def is_preprocessing_applied_before_resize(self, to_float: Node, mul: Node, sub: Node):
|
||||
"""
|
||||
The function checks if the output of 'to_float' operation is consumed by 'mul' or 'sub'. If this is true then
|
||||
the pre-processing (mean/scale) is applied before the image resize. The image resize was applied first in the
|
||||
original version of the TF OD API models, but in the recent versions it is applied after.
|
||||
|
||||
:param to_float: the Cast node which converts the input tensor to Float
|
||||
:param mul: the Mul node (can be None)
|
||||
:param sub: the Sub node
|
||||
:return: the result of the check
|
||||
"""
|
||||
assert sub is not None, 'The Sub node should not be None. Check the caller function.'
|
||||
if mul is not None:
|
||||
return any([port.node.id == mul.id for port in to_float.out_port(0).get_destinations()])
|
||||
else:
|
||||
return any([port.node.id == sub.id for port in to_float.out_port(0).get_destinations()])
|
||||
|
||||
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
|
||||
argv = graph.graph['cmd_params']
|
||||
layout = graph.graph['layout']
|
||||
@ -494,14 +511,14 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
|
||||
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
|
||||
|
||||
sub_node = match.output_node(0)[0]
|
||||
if not sub_node.has('op') or sub_node.op != 'Sub':
|
||||
if sub_node.soft_get('op') != 'Sub':
|
||||
raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
|
||||
'not created with TensorFlow Object Detection API.')
|
||||
|
||||
mul_node = None
|
||||
if sub_node.in_node(0).has('op') and sub_node.in_node(0).op == 'Mul':
|
||||
if sub_node.in_port(0).get_source().node.soft_get('op') == 'Mul':
|
||||
log.info('There is image scaling node in the Preprocessor block.')
|
||||
mul_node = sub_node.in_node(0)
|
||||
mul_node = sub_node.in_port(0).get_source().node
|
||||
|
||||
initial_input_node_name = 'image_tensor'
|
||||
if initial_input_node_name not in graph.nodes():
|
||||
@ -521,16 +538,25 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
|
||||
graph.graph['preprocessed_image_height'] = placeholder_node.shape[get_height_dim(layout, 4)]
|
||||
graph.graph['preprocessed_image_width'] = placeholder_node.shape[get_width_dim(layout, 4)]
|
||||
|
||||
to_float_node = placeholder_node.out_node(0)
|
||||
if not to_float_node.has('op') or to_float_node.op != 'Cast':
|
||||
raise Error('The output of the node "{}" is not Cast operation. Cannot apply replacer.'.format(
|
||||
to_float_node = placeholder_node.out_port(0).get_destination().node
|
||||
if to_float_node.soft_get('op') != 'Cast':
|
||||
raise Error('The output of the node "{}" is not Cast operation. Cannot apply transformation.'.format(
|
||||
initial_input_node_name))
|
||||
|
||||
# connect to_float_node directly with node performing scale on mean value subtraction
|
||||
if mul_node is None:
|
||||
graph.create_edge(to_float_node, sub_node, 0, 0)
|
||||
if self.is_preprocessing_applied_before_resize(to_float_node, mul_node, sub_node):
|
||||
# connect sub node directly to nodes which consume resized image
|
||||
resize_output_node_id = 'Preprocessor/map/TensorArrayStack/TensorArrayGatherV3'
|
||||
if resize_output_node_id not in graph.nodes:
|
||||
raise Error('There is no expected node "{}" in the graph.'.format(resize_output_node_id))
|
||||
resize_output = Node(graph, resize_output_node_id)
|
||||
for dst_port in resize_output.out_port(0).get_destinations():
|
||||
dst_port.get_connection().set_source(sub_node.out_port(0))
|
||||
else:
|
||||
graph.create_edge(to_float_node, mul_node, 0, 1)
|
||||
# connect to_float_node directly with node performing scale on mean value subtraction
|
||||
if mul_node is None:
|
||||
to_float_node.out_port(0).connect(sub_node.in_port(0))
|
||||
else:
|
||||
to_float_node.out_port(0).connect(mul_node.in_port(1))
|
||||
|
||||
print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
|
||||
' applicable) are kept.')
|
||||
|
Loading…
Reference in New Issue
Block a user