Enable TF 2.0 Object Detection API models (#3556)
* Config for TF 2.0 Faster R-CNN models, refactored subgraph_between_nodes to use graph API * Added support for new type of Preprocessing block in the TF 2.0 OD API models. Various fixes to enable the Faster R-CNN ResNet 50 * Updated text comments * Fixed sub_graph_between_nodes for TensorIteratorMerge. Added support for the TF 2.X EfficientDet models (not yet reshape-able) * Fixed unit tests * Fixed regression for TF 1.X OD API SSD model, enabled TF 2.0 OD API SSD models * Code clean up * Switched TF 2.0 OD API Faster R-CNN to preprocessor replacement type 2 * Refactored ObjectDetectionAPIPreprocessorReplacement and ObjectDetectionAPIPreprocessor2Replacement * Fixed bug in the Div transformation to Mul when input is integer. * Added support for the TF 2.0 OD API Mask R-CNN * Added unit tests for Div operation. Updated incorrectly modified mask_rcnn_support_api_v1.14.json * Updated document with list of supported configuration files for TF OD API models * Review comments * Added tests for control flow edges for the sub_graph_between_nodes function * Two more tests
This commit is contained in:
parent
129a6553fa
commit
c6bfac6e05
@ -8,29 +8,33 @@
|
||||
|
||||
With 2018 R3 release, the Model Optimizer introduces a new approach to convert models created using the TensorFlow\* Object Detection API. Compared with the previous approach, the new process produces inference results with higher accuracy and does not require modifying any configuration files and providing intricate command line parameters.
|
||||
|
||||
You can download TensorFlow\* Object Detection API models from the <a href="https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md">Object Detection Model Zoo</a>.
|
||||
You can download TensorFlow\* Object Detection API models from the <a href="https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md">TensorFlow 1 Detection Model Zoo</a>
|
||||
or <a href="https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md">TensorFlow 2 Detection Model Zoo</a>.
|
||||
|
||||
<strong>NOTE</strong>: Before converting, make sure you have configured the Model Optimizer. For configuration steps, refer to [Configuring the Model Optimizer](../../Config_Model_Optimizer.md).
|
||||
|
||||
To convert a TensorFlow\* Object Detection API model, go to the `<INSTALL_DIR>/deployment_tools/model_optimizer` directory and run the `mo_tf.py` script with the following required parameters:
|
||||
|
||||
* `--input_model <path_to_frozen.pb>` --- File with a pre-trained model (binary or text .pb file after freezing)
|
||||
* `--input_model <path_to_frozen.pb>` --- File with a pre-trained model (binary or text .pb file after freezing) OR `--saved_model_dir <path_to_saved_model>` for the TensorFlow\* 2 models
|
||||
* `--transformations_config <path_to_subgraph_replacement_configuration_file.json>` --- A subgraph replacement configuration file with transformations description. For the models downloaded from the TensorFlow\* Object Detection API zoo, you can find the configuration files in the `<INSTALL_DIR>/deployment_tools/model_optimizer/extensions/front/tf` directory. Use:
|
||||
* `ssd_v2_support.json` --- for frozen SSD topologies from the models zoo version up to 1.13.X inclusively
|
||||
* `ssd_support_api_v.1.14.json` --- for frozen SSD topologies trained manually using the TensorFlow* Object Detection API version 1.14 up to 1.14.X inclusively
|
||||
* `ssd_support_api_v.1.15.json` --- for frozen SSD topologies trained manually using the TensorFlow* Object Detection API version 1.15 or higher
|
||||
* `ssd_support_api_v.1.14.json` --- for frozen SSD topologies trained using the TensorFlow\* Object Detection API version 1.14 up to 1.14.X inclusively
|
||||
* `ssd_support_api_v.1.15.json` --- for frozen SSD topologies trained using the TensorFlow\* Object Detection API version 1.15 up to 2.0
|
||||
* `ssd_support_api_v.2.0.json` --- for frozen SSD topologies trained using the TensorFlow\* Object Detection API version 2.0 or higher
|
||||
* `faster_rcnn_support.json` --- for frozen Faster R-CNN topologies from the models zoo
|
||||
* `faster_rcnn_support_api_v1.7.json` --- for Faster R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.7.0 up to 1.9.X inclusively
|
||||
* `faster_rcnn_support_api_v1.10.json` --- for Faster R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.10.0 up to 1.12.X inclusively
|
||||
* `faster_rcnn_support_api_v1.13.json` --- for Faster R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.13.X
|
||||
* `faster_rcnn_support_api_v1.14.json` --- for Faster R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.14.0 up to 1.14.X inclusively
|
||||
* `faster_rcnn_support_api_v1.15.json` --- for Faster R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.15.0 or higher
|
||||
* `faster_rcnn_support_api_v1.7.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.7.0 up to 1.9.X inclusively
|
||||
* `faster_rcnn_support_api_v1.10.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.10.0 up to 1.12.X inclusively
|
||||
* `faster_rcnn_support_api_v1.13.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.13.X
|
||||
* `faster_rcnn_support_api_v1.14.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.14.0 up to 1.14.X inclusively
|
||||
* `faster_rcnn_support_api_v1.15.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.15.0 up to 2.0
|
||||
* `faster_rcnn_support_api_v2.0.json` --- for Faster R-CNN topologies trained using the TensorFlow\* Object Detection API version 2.0 or higher
|
||||
* `mask_rcnn_support.json` --- for frozen Mask R-CNN topologies from the models zoo
|
||||
* `mask_rcnn_support_api_v1.7.json` --- for Mask R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.7.0 up to 1.9.X inclusively
|
||||
* `mask_rcnn_support_api_v1.11.json` --- for Mask R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.11.0 up to 1.12.X inclusively
|
||||
* `mask_rcnn_support_api_v1.13.json` --- for Mask R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.13.0 up to 1.13.X inclusively
|
||||
* `mask_rcnn_support_api_v1.14.json` --- for Mask R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.14.0 up to 1.14.X inclusively
|
||||
* `mask_rcnn_support_api_v1.15.json` --- for Mask R-CNN topologies trained manually using the TensorFlow* Object Detection API version 1.15.0 or higher
|
||||
* `mask_rcnn_support_api_v1.7.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.7.0 up to 1.9.X inclusively
|
||||
* `mask_rcnn_support_api_v1.11.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.11.0 up to 1.12.X inclusively
|
||||
* `mask_rcnn_support_api_v1.13.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.13.0 up to 1.13.X inclusively
|
||||
* `mask_rcnn_support_api_v1.14.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.14.0 up to 1.14.X inclusively
|
||||
* `mask_rcnn_support_api_v1.15.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 1.15.0 up to 2.0
|
||||
* `mask_rcnn_support_api_v2.0.json` --- for Mask R-CNN topologies trained using the TensorFlow\* Object Detection API version 2.0 or higher
|
||||
* `rfcn_support.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.9.0 or lower.
|
||||
* `rfcn_support_api_v1.10.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.10.0 up to 1.12.X inclusively
|
||||
* `rfcn_support_api_v1.13.json` --- for the frozen RFCN topology from the models zoo frozen with TensorFlow\* version 1.13.X.
|
||||
|
@ -373,6 +373,7 @@ extensions/front/tf/CTCLossReplacement.py
|
||||
extensions/front/tf/cumsum_ext.py
|
||||
extensions/front/tf/deconv_ext.py
|
||||
extensions/front/tf/depth_to_space.py
|
||||
extensions/front/tf/efficient_det_support_api_v2.0.json
|
||||
extensions/front/tf/elementwise_ext.py
|
||||
extensions/front/tf/embedding_segments_sum.py
|
||||
extensions/front/tf/expand_dims_ext.py
|
||||
@ -386,6 +387,7 @@ extensions/front/tf/faster_rcnn_support_api_v1.13.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v1.14.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v1.15.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v1.7.json
|
||||
extensions/front/tf/faster_rcnn_support_api_v2.0.json
|
||||
extensions/front/tf/fifo_queue_v2_ext.py
|
||||
extensions/front/tf/fifo_replacer.py
|
||||
extensions/front/tf/fill_ext.py
|
||||
@ -410,6 +412,7 @@ extensions/front/tf/mask_rcnn_support_api_v1.13.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v1.14.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v1.15.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v1.7.json
|
||||
extensions/front/tf/mask_rcnn_support_api_v2.0.json
|
||||
extensions/front/tf/matmul_ext.py
|
||||
extensions/front/tf/mvn.py
|
||||
extensions/front/tf/mvn_unrolled.py
|
||||
@ -457,6 +460,7 @@ extensions/front/tf/split_ext.py
|
||||
extensions/front/tf/ssd_support.json
|
||||
extensions/front/tf/ssd_support_api_v1.14.json
|
||||
extensions/front/tf/ssd_support_api_v1.15.json
|
||||
extensions/front/tf/ssd_support_api_v2.0.json
|
||||
extensions/front/tf/ssd_toolbox_detection_output.json
|
||||
extensions/front/tf/ssd_toolbox_multihead_detection_output.json
|
||||
extensions/front/tf/ssd_v2_support.json
|
||||
@ -1011,4 +1015,4 @@ requirements_kaldi.txt
|
||||
requirements_mxnet.txt
|
||||
requirements_onnx.txt
|
||||
requirements_tf.txt
|
||||
requirements_tf2.txt
|
||||
requirements_tf2.txt
|
||||
|
@ -33,6 +33,11 @@ class Div(FrontReplacementPattern):
|
||||
if div.in_port(0).data.get_value() is not None and div.in_port(1).data.get_value() is not None:
|
||||
return
|
||||
|
||||
# cannot replace Div with Mul when the divisor is integer because the reciprocal number will be 0
|
||||
value = div.in_port(1).data.get_value()
|
||||
if value is not None and type(value.item(0)) == int:
|
||||
return
|
||||
|
||||
graph = div.graph
|
||||
name = div.soft_get('name', div.id)
|
||||
|
||||
|
@ -78,3 +78,21 @@ class TestDiv(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
self.assertTrue(graph.node[graph.get_nodes_with_attributes(type='Multiply')[0]]['name'] == 'my_div')
|
||||
|
||||
def test_div_with_integer(self):
|
||||
# Test where transformation should not be applied because the divisor is integer
|
||||
graph = build_graph({
|
||||
**regular_op_with_shaped_data('parameter', [1, 227, 227, 3], {'type': 'Parameter', 'data_type': np.int32}),
|
||||
**valued_const_with_data('const', np.array([-1.], dtype=np.int32)),
|
||||
**regular_op_with_shaped_data('div', None, {'op': 'Div', 'type': 'Divide', 'name': 'my_div'}),
|
||||
**result()},
|
||||
[
|
||||
*connect('parameter:0', '0:div'),
|
||||
*connect_data('const:0', '1:div'),
|
||||
*connect('div', 'output'),
|
||||
])
|
||||
graph_ref = graph.copy()
|
||||
Div().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
@ -240,7 +240,8 @@ def _create_prior_boxes_node(graph: Graph, pipeline_config: PipelineConfig):
|
||||
# connect the PriorBoxClustered node with the "Cast" node of the Placeholder node because the pass that removes
|
||||
# Cast operations is executed in the middle phase and it will fail when there are several consumers of the
|
||||
# Placeholder
|
||||
prior_box_node = prior_box_op.create_node([ssd_head_node, Node(graph, 'image_tensor').out_node(0)],
|
||||
input_node_name = 'image_tensor' if 'image_tensor' in graph.nodes else 'input_tensor'
|
||||
prior_box_node = prior_box_op.create_node([ssd_head_node, Node(graph, input_node_name).out_node(0)],
|
||||
{'name': 'PriorBoxClustered_{}'.format(ssd_head_ind)})
|
||||
prior_box_nodes.append(prior_box_node)
|
||||
if len(prior_box_nodes) == 1:
|
||||
@ -399,16 +400,17 @@ def calculate_placeholder_spatial_shape(graph: Graph, match: SubgraphMatch, pipe
|
||||
width = None
|
||||
user_shapes = graph.graph['user_shapes']
|
||||
|
||||
if 'preprocessed_image_height' in match.custom_replacement_desc.custom_attributes or 'preprocessed_image_width' in \
|
||||
match.custom_replacement_desc.custom_attributes:
|
||||
if match and ('preprocessed_image_height' in match.custom_replacement_desc.custom_attributes or
|
||||
'preprocessed_image_width' in match.custom_replacement_desc.custom_attributes):
|
||||
log.error('The "preprocessed_image_height" or "preprocessed_image_width" is specified in the sub-graph '
|
||||
'replacement configuration file but they are ignored. Please, specify desired input shape using the '
|
||||
'"--input_shape" command line parameter.', extra={'is_warning': True})
|
||||
|
||||
user_defined_height = None
|
||||
user_defined_width = None
|
||||
if user_shapes and 'image_tensor' in user_shapes and user_shapes['image_tensor']:
|
||||
user_defined_shape = user_shapes['image_tensor'][0]['shape']
|
||||
input_name = 'input_tensor' if 'input_tensor' in graph.nodes else 'image_tensor'
|
||||
if user_shapes and input_name in user_shapes and user_shapes[input_name]:
|
||||
user_defined_shape = user_shapes[input_name][0]['shape']
|
||||
if user_defined_shape is not None:
|
||||
user_defined_height = user_defined_shape[1]
|
||||
user_defined_width = user_defined_shape[2]
|
||||
@ -464,6 +466,43 @@ def calculate_placeholder_spatial_shape(graph: Graph, match: SubgraphMatch, pipe
|
||||
return height, width
|
||||
|
||||
|
||||
def update_parameter_shape(graph: Graph, match: [SubgraphMatch, None]):
|
||||
"""
|
||||
Updates the shape of the model Parameter node based on the user provided input shape or values provided in the
|
||||
pipeline.config configuration file used for model training.
|
||||
:param graph: model graph
|
||||
:param match: Match object with information abouot matched sub-graph
|
||||
:return: tupe with input node names and Parameter Node
|
||||
"""
|
||||
argv = graph.graph['cmd_params']
|
||||
if argv.tensorflow_object_detection_api_pipeline_config is None:
|
||||
raise Error(missing_param_error)
|
||||
|
||||
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
|
||||
if argv.tensorflow_object_detection_api_pipeline_config is None:
|
||||
raise Error(missing_param_error)
|
||||
|
||||
initial_input_node_name = 'input_tensor' if 'input_tensor' in graph.nodes else 'image_tensor'
|
||||
if initial_input_node_name not in graph.nodes():
|
||||
raise Error('Input node "{}" of the graph is not found. Do not run the Model Optimizer with '
|
||||
'"--input" command line parameter.'.format(initial_input_node_name))
|
||||
parameter_node = Node(graph, initial_input_node_name)
|
||||
|
||||
# set default value of the batch size to 1 if user didn't specify batch size and input shape
|
||||
layout = graph.graph['layout']
|
||||
batch_dim = get_batch_dim(layout, 4)
|
||||
if argv.batch is None and parameter_node.shape[batch_dim] == -1:
|
||||
parameter_node.shape[batch_dim] = 1
|
||||
height, width = calculate_placeholder_spatial_shape(graph, match, pipeline_config)
|
||||
parameter_node.shape[get_height_dim(layout, 4)] = height
|
||||
parameter_node.shape[get_width_dim(layout, 4)] = width
|
||||
|
||||
# save the pre-processed image spatial sizes to be used in the other replacers
|
||||
graph.graph['preprocessed_image_height'] = parameter_node.shape[get_height_dim(layout, 4)]
|
||||
graph.graph['preprocessed_image_width'] = parameter_node.shape[get_width_dim(layout, 4)]
|
||||
return initial_input_node_name, parameter_node
|
||||
|
||||
|
||||
class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
"""
|
||||
The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
|
||||
@ -504,12 +543,6 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
|
||||
return any([port.node.id == sub.id for port in to_float.out_port(0).get_destinations()])
|
||||
|
||||
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
|
||||
argv = graph.graph['cmd_params']
|
||||
layout = graph.graph['layout']
|
||||
if argv.tensorflow_object_detection_api_pipeline_config is None:
|
||||
raise Error(missing_param_error)
|
||||
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
|
||||
|
||||
sub_node = match.output_node(0)[0]
|
||||
if sub_node.soft_get('op') != 'Sub':
|
||||
raise Error('The output op of the Preprocessor sub-graph is not of type "Sub". Looks like the topology is '
|
||||
@ -520,23 +553,7 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
|
||||
log.info('There is image scaling node in the Preprocessor block.')
|
||||
mul_node = sub_node.in_port(0).get_source().node
|
||||
|
||||
initial_input_node_name = 'image_tensor'
|
||||
if initial_input_node_name not in graph.nodes():
|
||||
raise Error('Input node "{}" of the graph is not found. Do not run the Model Optimizer with '
|
||||
'"--input" command line parameter.'.format(initial_input_node_name))
|
||||
placeholder_node = Node(graph, initial_input_node_name)
|
||||
|
||||
# set default value of the batch size to 1 if user didn't specify batch size and input shape
|
||||
batch_dim = get_batch_dim(layout, 4)
|
||||
if argv.batch is None and placeholder_node.shape[batch_dim] == -1:
|
||||
placeholder_node.shape[batch_dim] = 1
|
||||
height, width = calculate_placeholder_spatial_shape(graph, match, pipeline_config)
|
||||
placeholder_node.shape[get_height_dim(layout, 4)] = height
|
||||
placeholder_node.shape[get_width_dim(layout, 4)] = width
|
||||
|
||||
# save the pre-processed image spatial sizes to be used in the other replacers
|
||||
graph.graph['preprocessed_image_height'] = placeholder_node.shape[get_height_dim(layout, 4)]
|
||||
graph.graph['preprocessed_image_width'] = placeholder_node.shape[get_width_dim(layout, 4)]
|
||||
initial_input_node_name, placeholder_node = update_parameter_shape(graph, match)
|
||||
|
||||
to_float_node = placeholder_node.out_port(0).get_destination().node
|
||||
if to_float_node.soft_get('op') != 'Cast':
|
||||
@ -563,6 +580,41 @@ class ObjectDetectionAPIPreprocessorReplacement(FrontReplacementFromConfigFileSu
|
||||
return {}
|
||||
|
||||
|
||||
class ObjectDetectionAPIPreprocessor2Replacement(FrontReplacementFromConfigFileGeneral):
|
||||
"""
|
||||
The class replaces the "Preprocessor" block resizing input image and applying mean/scale values. Only nodes related
|
||||
to applying mean/scaling values are kept. The transformation is used for TensorFlow 2.X models.
|
||||
"""
|
||||
replacement_id = 'ObjectDetectionAPIPreprocessor2Replacement'
|
||||
|
||||
def run_before(self):
|
||||
# PadTFToPad inserts Transpose ops for Pad ops inside the sub-graph corresponding to DetectionOutput.
|
||||
# But the inputs corresponding to padding values is re-used as inputs for newly created Pad node. This input
|
||||
# is removed during removing nodes from the DO sub-graph so the first input to Transpose is missing which
|
||||
# results in TransposeOrderNormalizer transformation failure.
|
||||
return [Pack, TransposeOrderNormalizer, PadTFToPad]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
|
||||
update_parameter_shape(graph, None)
|
||||
|
||||
start_nodes = replacement_descriptions['start_nodes']
|
||||
end_nodes = replacement_descriptions['end_nodes']
|
||||
|
||||
assert len(start_nodes) >= 1
|
||||
assert start_nodes[0] in graph.nodes
|
||||
input_node = Node(graph, start_nodes[0])
|
||||
|
||||
assert len(end_nodes) >= 1
|
||||
assert end_nodes[0] in graph.nodes
|
||||
output_node = Node(graph, end_nodes[0])
|
||||
|
||||
output_node.out_port(0).get_connection().set_source(input_node.in_port(0).get_source())
|
||||
input_node.in_port(0).disconnect()
|
||||
|
||||
print('The Preprocessor block has been removed. Only nodes performing mean value subtraction and scaling (if'
|
||||
' applicable) are kept.')
|
||||
|
||||
|
||||
class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
"""
|
||||
Replaces the sub-graph that is equal to the DetectionOutput layer from Inference Engine. This replacer is used for
|
||||
@ -606,6 +658,7 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
if argv.tensorflow_object_detection_api_pipeline_config is None:
|
||||
raise Error(missing_param_error)
|
||||
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
|
||||
custom_attributes = match.custom_replacement_desc.custom_attributes
|
||||
|
||||
num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
|
||||
max_proposals = _value_or_raise(match, pipeline_config, 'first_stage_max_proposals')
|
||||
@ -630,18 +683,27 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
current_node = skip_nodes_by_condition(match.single_input_node(0)[0].in_node(0),
|
||||
lambda x: x['kind'] == 'op' and x.has_and_set('reinterp_shape'))
|
||||
|
||||
reshape_loc_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, num_classes, 1, 4]),
|
||||
dict(name='reshape_loc'), current_node)
|
||||
share_box_across_classes = _value_or_raise(match, pipeline_config, 'share_box_across_classes')
|
||||
background_label_id = int(custom_attributes.get('background_label_id', 0))
|
||||
if share_box_across_classes:
|
||||
reshape_loc_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 1, 1, 4]),
|
||||
dict(name='reshape_loc'), current_node)
|
||||
else:
|
||||
reshape_loc_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, num_classes, 1, 4]),
|
||||
dict(name='reshape_loc'), current_node)
|
||||
mark_as_correct_data_layout(reshape_loc_node)
|
||||
|
||||
# constant node with variances
|
||||
variances_const_op = Const(graph, dict(value=_variance_from_pipeline_config(pipeline_config)))
|
||||
variances_const_node = variances_const_op.create_node([])
|
||||
|
||||
# TF produces locations tensor without boxes for background.
|
||||
# Inference Engine DetectionOutput layer requires background boxes so we generate them
|
||||
loc_node = add_fake_background_loc(graph, reshape_loc_node)
|
||||
PermuteAttrs.set_permutation(reshape_loc_node, loc_node, None)
|
||||
if share_box_across_classes:
|
||||
loc_node = reshape_loc_node
|
||||
else:
|
||||
# TF produces locations tensor without boxes for background.
|
||||
# Inference Engine DetectionOutput layer requires background boxes so we generate them
|
||||
loc_node = add_fake_background_loc(graph, reshape_loc_node)
|
||||
PermuteAttrs.set_permutation(reshape_loc_node, loc_node, None)
|
||||
|
||||
# reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
|
||||
reshape_loc_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 4]),
|
||||
@ -659,7 +721,6 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
# calculate the second dimension so the batch value will be deduced from it with help of "-1".
|
||||
reshape_loc_do_op = Reshape(graph, dict(name='do_reshape_locs'))
|
||||
|
||||
custom_attributes = match.custom_replacement_desc.custom_attributes
|
||||
coordinates_swap_method = 'add_convolution'
|
||||
if 'coordinates_swap_method' not in custom_attributes:
|
||||
log.error('The ObjectDetectionAPIDetectionOutputReplacement sub-graph replacement configuration file '
|
||||
@ -683,8 +744,12 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
else:
|
||||
reshape_loc_do_node = reshape_loc_do_op.create_node([eltwise_locs_node])
|
||||
|
||||
reshape_loc_do_dims = Const(graph, {'value': int64_array([-1, (num_classes + 1) * max_proposals * 4]),
|
||||
'name': reshape_loc_do_node.name + '/Dim'}).create_node()
|
||||
if share_box_across_classes:
|
||||
reshape_loc_do_dims = Const(graph, {'value': int64_array([-1, max_proposals * 4]),
|
||||
'name': reshape_loc_do_node.name + '/Dim'}).create_node()
|
||||
else:
|
||||
reshape_loc_do_dims = Const(graph, {'value': int64_array([-1, (num_classes + 1) * max_proposals * 4]),
|
||||
'name': reshape_loc_do_node.name + '/Dim'}).create_node()
|
||||
reshape_loc_do_dims.out_port(0).connect(reshape_loc_do_node.in_port(1))
|
||||
|
||||
mark_as_correct_data_layout(reshape_loc_do_node)
|
||||
@ -716,7 +781,10 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
|
||||
detection_output_node = detection_output_op.create_node(
|
||||
[reshape_loc_do_node, reshape_conf_node, reshape_priors_node],
|
||||
dict(name=detection_output_op.attrs['type'], share_location=0, variance_encoded_in_target=1,
|
||||
dict(name=detection_output_op.attrs['type'],
|
||||
share_location=int(share_box_across_classes),
|
||||
variance_encoded_in_target=1,
|
||||
background_label_id=background_label_id,
|
||||
code_type='caffe.PriorBoxParameter.CENTER_SIZE', pad_mode='caffe.ResizeParameter.CONSTANT',
|
||||
resize_mode='caffe.ResizeParameter.WARP',
|
||||
num_classes=num_classes,
|
||||
@ -730,12 +798,15 @@ class ObjectDetectionAPIDetectionOutputReplacement(FrontReplacementFromConfigFil
|
||||
if coordinates_swap_method == 'swap_weights':
|
||||
swap_weights_xy(graph, backward_bfs_for_operation(detection_output_node.in_node(0), ['MatMul', 'Conv2D']))
|
||||
|
||||
# when the use_matmul_crop_and_resize = True then the prior boxes were not swapped and we need to swap them from
|
||||
# YXYX to XYXY before passing to the DetectionOutput operation
|
||||
if pipeline_config.get_param('use_matmul_crop_and_resize'):
|
||||
insert_weights_swap_xy_sub_graph(graph, detection_output_node.in_port(2).get_connection())
|
||||
output_op = Result(graph, dict(name='do_OutputOp'))
|
||||
output_op.create_node([detection_output_node])
|
||||
|
||||
print('The graph output nodes "num_detections", "detection_boxes", "detection_classes", "detection_scores" '
|
||||
'have been replaced with a single layer of type "Detection Output". Refer to IR catalogue in the '
|
||||
'documentation for information about this layer.')
|
||||
print('The graph output nodes have been replaced with a single layer of type "DetectionOutput". Refer to the '
|
||||
'operation set specification documentation for more information about the operation.')
|
||||
|
||||
return {'detection_output_node': detection_output_node}
|
||||
|
||||
@ -818,17 +889,18 @@ class ObjectDetectionAPIMaskRCNNSigmoidReplacement(FrontReplacementFromConfigFil
|
||||
return [ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions):
|
||||
masks_node_prefix_name = replacement_descriptions.get('masks_node_prefix_name', 'SecondStageBoxPredictor')
|
||||
op_outputs = graph.get_op_nodes(op='Result')
|
||||
for op_output in op_outputs:
|
||||
last_node = op_output.in_port(0).get_source().node
|
||||
if last_node.name.startswith('SecondStageBoxPredictor'):
|
||||
if last_node.name.startswith(masks_node_prefix_name):
|
||||
sigmoid_node = Sigmoid(graph, dict(name='masks')).create_node()
|
||||
op_output.in_port(0).get_connection().insert_node(sigmoid_node)
|
||||
|
||||
print('The predicted masks are produced by the "masks" layer for each bounding box generated with a '
|
||||
'"detection_output" layer.\n Refer to IR catalogue in the documentation for information '
|
||||
'about the DetectionOutput layer and Inference Engine documentation about output data interpretation.\n'
|
||||
'The topology can be inferred using dedicated demo "mask_rcnn_demo".')
|
||||
'"detection_output" operation.\n Refer to operation specification in the documentation for information '
|
||||
'about the DetectionOutput operation output data interpretation.\n'
|
||||
'The model can be inferred using the dedicated demo "mask_rcnn_demo" from the OpenVINO Open Model Zoo.')
|
||||
|
||||
|
||||
class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGraph):
|
||||
@ -840,10 +912,10 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
|
||||
replacement_id = 'ObjectDetectionAPIProposalReplacement'
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement]
|
||||
|
||||
def run_before(self):
|
||||
return [CropAndResizeReplacement, TransposeOrderNormalizer]
|
||||
return [CropAndResizeReplacement, TransposeOrderNormalizer, Pack]
|
||||
|
||||
def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
|
||||
return {match.output_node(0)[0].id: new_sub_graph['proposal_node'].id}
|
||||
@ -945,35 +1017,35 @@ class ObjectDetectionAPIProposalReplacement(FrontReplacementFromConfigFileSubGra
|
||||
proposal_node = proposal_op.create_node([reshape_permute_node, anchors_node, input_with_image_size_node],
|
||||
dict(name='proposals'))
|
||||
|
||||
if 'do_not_swap_proposals' in match.custom_replacement_desc.custom_attributes and \
|
||||
match.custom_replacement_desc.custom_attributes['do_not_swap_proposals']:
|
||||
swapped_proposals_node = proposal_node
|
||||
else:
|
||||
swapped_proposals_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 5)
|
||||
# models with use_matmul_crop_and_resize = True should not swap order of elements (YX to XY) after the Proposal
|
||||
swap_proposals = not match.custom_replacement_desc.custom_attributes.get('do_not_swap_proposals', False) and \
|
||||
not pipeline_config.get_param('use_matmul_crop_and_resize')
|
||||
|
||||
if swap_proposals:
|
||||
proposal_node = add_convolution_to_swap_xy_coordinates(graph, proposal_node, 5)
|
||||
|
||||
proposal_reshape_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
|
||||
dict(name="reshape_swap_proposals_2d"),
|
||||
swapped_proposals_node)
|
||||
proposal_node)
|
||||
mark_input_as_in_correct_layout(proposal_reshape_2d_node, 0)
|
||||
|
||||
# feed the CropAndResize node with a correct boxes information produced with the Proposal layer
|
||||
# find the first CropAndResize node in the BFS order
|
||||
crop_and_resize_nodes_ids = [node_id for node_id in bfs_search(graph, [match.single_input_node(0)[0].id]) if
|
||||
graph.node[node_id]['op'] == 'CropAndResize']
|
||||
assert len(crop_and_resize_nodes_ids) != 0, "Didn't find any CropAndResize nodes in the graph."
|
||||
if 'do_not_swap_proposals' not in match.custom_replacement_desc.custom_attributes or not \
|
||||
match.custom_replacement_desc.custom_attributes['do_not_swap_proposals']:
|
||||
if len(crop_and_resize_nodes_ids) != 0 and swap_proposals:
|
||||
# feed the CropAndResize node with a correct boxes information produced with the Proposal layer
|
||||
# find the first CropAndResize node in the BFS order. This is needed in the case when we already swapped
|
||||
# box coordinates data after the Proposal node
|
||||
crop_and_resize_node = Node(graph, crop_and_resize_nodes_ids[0])
|
||||
# set a marker that the input with box coordinates has been pre-processed so the CropAndResizeReplacement
|
||||
# set a marker that an input with box coordinates has been pre-processed so the CropAndResizeReplacement
|
||||
# transform doesn't try to merge the second and the third inputs
|
||||
crop_and_resize_node['inputs_preprocessed'] = True
|
||||
graph.remove_edge(crop_and_resize_node.in_node(1).id, crop_and_resize_node.id)
|
||||
graph.create_edge(proposal_reshape_2d_node, crop_and_resize_node, out_port=0, in_port=1)
|
||||
crop_and_resize_node.in_port(1).disconnect()
|
||||
proposal_reshape_2d_node.out_port(0).connect(crop_and_resize_node.in_port(1))
|
||||
|
||||
tf_proposal_reshape_4d_node = create_op_node_with_second_input(graph, Reshape,
|
||||
int64_array([-1, 1, max_proposals, 5]),
|
||||
dict(name="reshape_proposal_4d"),
|
||||
swapped_proposals_node)
|
||||
proposal_node)
|
||||
|
||||
crop_op = Crop(graph, dict(axis=int64_array([3]), offset=int64_array([1]), dim=int64_array([4]),
|
||||
nchw_layout=True))
|
||||
@ -991,7 +1063,8 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
replacement_id = 'ObjectDetectionAPISSDPostprocessorReplacement'
|
||||
|
||||
def run_after(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, FakeQuantWithMinMaxVarsToQuantize]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement,
|
||||
FakeQuantWithMinMaxVarsToQuantize]
|
||||
|
||||
def run_before(self):
|
||||
return [StandaloneConstEraser, TransposeOrderNormalizer, TFSliceToSliceReplacer]
|
||||
@ -1006,10 +1079,12 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
if argv.tensorflow_object_detection_api_pipeline_config is None:
|
||||
raise Error(missing_param_error)
|
||||
pipeline_config = PipelineConfig(argv.tensorflow_object_detection_api_pipeline_config)
|
||||
num_classes = _value_or_raise(match, pipeline_config, 'num_classes')
|
||||
|
||||
has_background_class = _value_or_raise(match, pipeline_config, 'add_background_class')
|
||||
num_classes = _value_or_raise(match, pipeline_config, 'num_classes') + has_background_class
|
||||
|
||||
# reshapes confidences to 4D before applying activation function and do not convert from NHWC to NCHW this node
|
||||
expand_dims_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, 1, -1, num_classes + 1]),
|
||||
expand_dims_node = create_op_node_with_second_input(graph, Reshape, int64_array([0, 1, -1, num_classes]),
|
||||
{'name': 'do_ExpandDims_conf'})
|
||||
expand_dims_node.in_port(0).connect(match.input_nodes(1)[0][0].in_node(0).out_port(0))
|
||||
|
||||
@ -1040,7 +1115,7 @@ class ObjectDetectionAPISSDPostprocessorReplacement(FrontReplacementFromConfigFi
|
||||
(pipeline_config.get_param('ssd_anchor_generator_num_layers') is not None or
|
||||
pipeline_config.get_param('multiscale_anchor_generator_min_level') is not None):
|
||||
# change the Reshape operations with hardcoded number of output elements of the convolution nodes to be
|
||||
# reshapable
|
||||
# reshape-able
|
||||
_relax_reshape_nodes(graph, pipeline_config)
|
||||
|
||||
# create PriorBoxClustered nodes instead of a constant value with prior boxes so the model could be reshaped
|
||||
@ -1120,7 +1195,8 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
|
||||
replacement_id = 'ObjectDetectionAPIOutputReplacement'
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, TransposeOrderNormalizer]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement,
|
||||
TransposeOrderNormalizer]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
|
||||
if graph.graph['cmd_params'].output is not None:
|
||||
@ -1215,7 +1291,8 @@ class ObjectDetectionAPIConstValueOverride(FrontReplacementFromConfigFileGeneral
|
||||
replacement_id = 'ObjectDetectionAPIConstValueOverride'
|
||||
|
||||
def run_before(self):
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, TransposeOrderNormalizer]
|
||||
return [ObjectDetectionAPIPreprocessorReplacement, ObjectDetectionAPIPreprocessor2Replacement,
|
||||
TransposeOrderNormalizer]
|
||||
|
||||
def transform_graph(self, graph: Graph, replacement_descriptions: dict):
|
||||
argv = graph.graph['cmd_params']
|
||||
|
@ -0,0 +1,50 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/Preprocessor/unstack"],
|
||||
"end_nodes": ["StatefulPartitionedCall/Preprocessor/stack",
|
||||
"StatefulPartitionedCall/Preprocessor/stack_1"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"code_type": "caffe.PriorBoxParameter.CENTER_SIZE",
|
||||
"pad_mode": "caffe.ResizeParameter.CONSTANT",
|
||||
"resize_mode": "caffe.ResizeParameter.WARP",
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"disable_prior_boxes_layers_generator": true
|
||||
},
|
||||
"id": "ObjectDetectionAPISSDPostprocessorReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/Identity",
|
||||
"StatefulPartitionedCall/Identity_1",
|
||||
"StatefulPartitionedCall/Identity_2",
|
||||
"StatefulPartitionedCall/Identity_3",
|
||||
"StatefulPartitionedCall/Identity_4",
|
||||
"StatefulPartitionedCall/Identity_5",
|
||||
"StatefulPartitionedCall/Identity_6",
|
||||
"StatefulPartitionedCall/Identity_7"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Postprocessor/Reshape_1",
|
||||
"StatefulPartitionedCall/Postprocessor/scale_logits",
|
||||
"StatefulPartitionedCall/Postprocessor/Tile",
|
||||
"StatefulPartitionedCall/Postprocessor/Cast_1"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/Identity,StatefulPartitionedCall/Identity_1,StatefulPartitionedCall/Identity_2,StatefulPartitionedCall/Identity_3,StatefulPartitionedCall/Identity_4,StatefulPartitionedCall/Identity_5,StatefulPartitionedCall/Identity_6,StatefulPartitionedCall/Identity_7"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -0,0 +1,82 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/Preprocessor/unstack"],
|
||||
"end_nodes": ["StatefulPartitionedCall/Preprocessor/stack",
|
||||
"StatefulPartitionedCall/Preprocessor/stack_1"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true
|
||||
},
|
||||
"id": "ObjectDetectionAPIProposalReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/stack_3",
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression/stack_10",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/concat/concat",
|
||||
"StatefulPartitionedCall/concat_1/concat",
|
||||
"StatefulPartitionedCall/GridAnchorGenerator/Identity",
|
||||
"StatefulPartitionedCall/Cast_1",
|
||||
"StatefulPartitionedCall/Cast_2",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"background_label_id": 0,
|
||||
"coordinates_swap_method": "swap_weights"
|
||||
},
|
||||
"id": "ObjectDetectionAPIDetectionOutputReplacement",
|
||||
"inputs": [
|
||||
[
|
||||
{
|
||||
"node": "Reshape$",
|
||||
"port": 0
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"node": "Reshape_1$",
|
||||
"port": 0
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"node": "ExpandDims$",
|
||||
"port": 0
|
||||
}
|
||||
]
|
||||
],
|
||||
"instances": [
|
||||
".*SecondStagePostprocessor/"
|
||||
],
|
||||
"match_kind": "scope",
|
||||
"outputs": [
|
||||
{
|
||||
"node": "Cast_3$",
|
||||
"port": 0
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/SecondStagePostprocessor/Cast_3"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -98,6 +98,7 @@
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "SecondStageBoxPredictor"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
@ -117,4 +118,4 @@
|
||||
"id": "ObjectDetectionAPIConstValueOverride",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -99,6 +99,7 @@
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "SecondStageBoxPredictor"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
@ -118,4 +119,4 @@
|
||||
"id": "ObjectDetectionAPIConstValueOverride",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -99,6 +99,7 @@
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "SecondStageBoxPredictor"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
@ -118,4 +119,4 @@
|
||||
"id": "ObjectDetectionAPIConstValueOverride",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -99,6 +99,7 @@
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "SecondStageBoxPredictor"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
|
@ -98,6 +98,7 @@
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "SecondStageBoxPredictor"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
@ -117,4 +118,4 @@
|
||||
"id": "ObjectDetectionAPIConstValueOverride",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -0,0 +1,91 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/Preprocessor/unstack"],
|
||||
"end_nodes": ["StatefulPartitionedCall/Preprocessor/stack",
|
||||
"StatefulPartitionedCall/Preprocessor/stack_1"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true
|
||||
},
|
||||
"id": "ObjectDetectionAPIProposalReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/stack_3",
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression/stack_10",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/concat/concat",
|
||||
"StatefulPartitionedCall/concat_1/concat",
|
||||
"StatefulPartitionedCall/GridAnchorGenerator/Identity",
|
||||
"StatefulPartitionedCall/Cast_1",
|
||||
"StatefulPartitionedCall/Cast_2",
|
||||
"StatefulPartitionedCall/Shape"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"background_label_id": 0,
|
||||
"coordinates_swap_method": "swap_weights"
|
||||
},
|
||||
"id": "ObjectDetectionAPIDetectionOutputReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression_1/stack_8",
|
||||
"StatefulPartitionedCall/BatchMultiClassNonMaxSuppression_1/stack_6"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Reshape_4",
|
||||
"StatefulPartitionedCall/Reshape_5",
|
||||
"StatefulPartitionedCall/ExpandDims_6",
|
||||
"StatefulPartitionedCall/Cast_5"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNROIPoolingSecondReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/Reshape_10"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/CropAndResize_1/CropAndResize",
|
||||
"StatefulPartitionedCall/CropAndResize_1/Reshape"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"masks_node_prefix_name": "StatefulPartitionedCall/mask_rcnn_keras_box_predictor/mask_rcnn_mask_head/"
|
||||
},
|
||||
"id": "ObjectDetectionAPIMaskRCNNSigmoidReplacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/mask_rcnn_keras_box_predictor/mask_rcnn_mask_head/MaskPredictor_last_conv2d/BiasAdd,StatefulPartitionedCall/Reshape_13"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -0,0 +1,50 @@
|
||||
[
|
||||
{
|
||||
"custom_attributes": {
|
||||
"start_nodes": ["StatefulPartitionedCall/Preprocessor/unstack"],
|
||||
"end_nodes": ["StatefulPartitionedCall/Preprocessor/stack",
|
||||
"StatefulPartitionedCall/Preprocessor/stack_1"]
|
||||
},
|
||||
"id": "ObjectDetectionAPIPreprocessor2Replacement",
|
||||
"match_kind": "general"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"code_type": "caffe.PriorBoxParameter.CENTER_SIZE",
|
||||
"pad_mode": "caffe.ResizeParameter.CONSTANT",
|
||||
"resize_mode": "caffe.ResizeParameter.WARP",
|
||||
"clip_before_nms": false,
|
||||
"clip_after_nms": true,
|
||||
"disable_prior_boxes_layers_generator": true
|
||||
},
|
||||
"id": "ObjectDetectionAPISSDPostprocessorReplacement",
|
||||
"include_inputs_to_sub_graph": true,
|
||||
"include_outputs_to_sub_graph": true,
|
||||
"instances": {
|
||||
"end_points": [
|
||||
"StatefulPartitionedCall/Identity",
|
||||
"StatefulPartitionedCall/Identity_1",
|
||||
"StatefulPartitionedCall/Identity_2",
|
||||
"StatefulPartitionedCall/Identity_3",
|
||||
"StatefulPartitionedCall/Identity_4",
|
||||
"StatefulPartitionedCall/Identity_5",
|
||||
"StatefulPartitionedCall/Identity_6",
|
||||
"StatefulPartitionedCall/Identity_7"
|
||||
],
|
||||
"start_points": [
|
||||
"StatefulPartitionedCall/Postprocessor/Reshape_1",
|
||||
"StatefulPartitionedCall/Postprocessor/scale_logits",
|
||||
"StatefulPartitionedCall/Postprocessor/Tile",
|
||||
"StatefulPartitionedCall/Postprocessor/Cast_1"
|
||||
]
|
||||
},
|
||||
"match_kind": "points"
|
||||
},
|
||||
{
|
||||
"custom_attributes": {
|
||||
"outputs": "StatefulPartitionedCall/Identity,StatefulPartitionedCall/Identity_1,StatefulPartitionedCall/Identity_2,StatefulPartitionedCall/Identity_3,StatefulPartitionedCall/Identity_4,StatefulPartitionedCall/Identity_5,StatefulPartitionedCall/Identity_6,StatefulPartitionedCall/Identity_7"
|
||||
},
|
||||
"id": "ObjectDetectionAPIOutputReplacement",
|
||||
"match_kind": "general"
|
||||
}
|
||||
]
|
@ -199,7 +199,7 @@ class SubgraphMatcher(object):
|
||||
node_name, self.replacement_desc.id))
|
||||
return None
|
||||
|
||||
matched_nodes = sub_graph_between_nodes(graph, start_points, end_points)
|
||||
matched_nodes = sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False)
|
||||
return SubgraphMatch(graph, self.replacement_desc, matched_nodes,
|
||||
self.replacement_desc.get_inputs_description(),
|
||||
self.replacement_desc.get_outputs_description(), '')
|
||||
|
@ -229,7 +229,7 @@ class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
|
||||
start_points = self.get_internal_input_nodes(graph)
|
||||
end_points = self.get_internal_output_nodes(graph)
|
||||
|
||||
matched_nodes = sub_graph_between_nodes(graph, start_points, end_points)
|
||||
matched_nodes = sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False)
|
||||
output_tensors = set()
|
||||
input_nodes_mapping = dict() # key is the input tensor name, value is the pair: (input_port, output_node_name)
|
||||
for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
|
||||
|
@ -117,13 +117,17 @@ def is_connected_component(graph: Graph, node_names: list):
|
||||
return set(node_names).issubset(visited)
|
||||
|
||||
|
||||
def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
|
||||
def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None,
|
||||
include_control_flow=True):
|
||||
"""
|
||||
Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. Input nodes for the sub-graph nodes are also
|
||||
added to the sub-graph. Constant inputs of the 'start_nodes' are also added to the sub-graph.
|
||||
:param graph: graph to operate on.
|
||||
:param start_nodes: list of nodes names that specifies start nodes.
|
||||
:param end_nodes: list of nodes names that specifies end nodes.
|
||||
:param detect_extra_start_node: callable function to add additional nodes to the list of start nodes instead of
|
||||
traversing the graph further. The list of additional start nodes is returned of the function is not None.
|
||||
:param include_control_flow: flag to specify whether to follow the control flow edges or not
|
||||
:return: list of nodes of the identified sub-graph or None if the sub-graph cannot be extracted.
|
||||
"""
|
||||
sub_graph_nodes = list()
|
||||
@ -133,23 +137,24 @@ def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, de
|
||||
|
||||
nx.set_node_attributes(G=graph, name='prev', values=None)
|
||||
while len(d) != 0:
|
||||
cur_node_name = d.popleft()
|
||||
sub_graph_nodes.append(cur_node_name)
|
||||
if cur_node_name not in end_nodes: # do not add output nodes of the end_nodes
|
||||
for _, dst_node_name in graph.out_edges(cur_node_name):
|
||||
if dst_node_name not in visited:
|
||||
cur_node_id = d.popleft()
|
||||
sub_graph_nodes.append(cur_node_id)
|
||||
if cur_node_id not in end_nodes: # do not add output nodes of the end_nodes
|
||||
for _, dst_node_name, attrs in graph.out_edges(cur_node_id, data=True):
|
||||
if dst_node_name not in visited and (include_control_flow or not attrs.get('control_flow_edge', False)):
|
||||
d.append(dst_node_name)
|
||||
visited.add(dst_node_name)
|
||||
graph.node[dst_node_name]['prev'] = cur_node_name
|
||||
graph.node[dst_node_name]['prev'] = cur_node_id
|
||||
|
||||
for src_node_name, _ in graph.in_edges(cur_node_name):
|
||||
for src_node_name, _, attrs in graph.in_edges(cur_node_id, data=True):
|
||||
# add input nodes for the non-start_nodes
|
||||
if cur_node_name not in start_nodes and src_node_name not in visited:
|
||||
if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
|
||||
extra_start_nodes.append(cur_node_name)
|
||||
if cur_node_id not in start_nodes and src_node_name not in visited and\
|
||||
(include_control_flow or not attrs.get('control_flow_edge', False)):
|
||||
if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_id)):
|
||||
extra_start_nodes.append(cur_node_id)
|
||||
else:
|
||||
d.append(src_node_name)
|
||||
graph.node[src_node_name]['prev'] = cur_node_name
|
||||
graph.node[src_node_name]['prev'] = cur_node_id
|
||||
visited.add(src_node_name)
|
||||
|
||||
# use forward dfs to check that all end nodes are reachable from at least one of input nodes
|
||||
@ -161,16 +166,16 @@ def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, de
|
||||
raise Error('End node "{}" is not reachable from start nodes: {}. '.format(end_node, start_nodes) +
|
||||
refer_to_faq_msg(74))
|
||||
|
||||
for node_name in sub_graph_nodes:
|
||||
for node_id in sub_graph_nodes:
|
||||
# sub-graph should not contain Placeholder nodes
|
||||
if graph.node[node_name].get('op', '') == 'Parameter':
|
||||
if graph.node[node_id].get('op', '') == 'Parameter':
|
||||
path = list()
|
||||
cur_node = node_name
|
||||
cur_node = node_id
|
||||
while cur_node and 'prev' in graph.node[cur_node]:
|
||||
path.append(str(cur_node))
|
||||
cur_node = graph.node[cur_node]['prev']
|
||||
log.debug("The path from input node is the following: {}".format('\n'.join(path)))
|
||||
raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
|
||||
raise Error('The matched sub-graph contains network input node "{}". '.format(node_id) +
|
||||
refer_to_faq_msg(75))
|
||||
if detect_extra_start_node is None:
|
||||
return sub_graph_nodes
|
||||
|
@ -211,3 +211,59 @@ class TestGraphUtils(unittest.TestCase):
|
||||
# after merging node 2 into sub-graph the node 2 will be removed and it is not known how to calculate the tensor
|
||||
# between node 2 and 3.
|
||||
self.assertListEqual(sorted(sub_graph_between_nodes(graph, [2], [8])), [n for n in node_names if n != 1])
|
||||
|
||||
def test_sub_graph_between_nodes_control_flow_included(self):
|
||||
"""
|
||||
Check that the function works correctly for case when control flow edges must be traversed (edge 5 -> 2).
|
||||
6 -> 5->
|
||||
\
|
||||
1 -> 2 -> 3 -> 4
|
||||
"""
|
||||
graph = Graph()
|
||||
graph.add_nodes_from(list(range(1, 7)))
|
||||
graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2, {'control_flow_edge': True}), (6, 5)])
|
||||
sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=True)
|
||||
self.assertIsNotNone(sub_graph_nodes)
|
||||
self.assertListEqual(sorted(sub_graph_nodes), sorted([1, 2, 3, 4, 5, 6]))
|
||||
|
||||
def test_sub_graph_between_nodes_control_flow_not_included(self):
|
||||
"""
|
||||
Check that the function works correctly for case when control flow edges should not be traversed (edge 5 -> 2).
|
||||
6 -> 5->
|
||||
\
|
||||
1 -> 2 -> 3 -> 4
|
||||
"""
|
||||
graph = Graph()
|
||||
graph.add_nodes_from(list(range(1, 7)))
|
||||
graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2, {'control_flow_edge': True}), (6, 5)])
|
||||
sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
|
||||
self.assertIsNotNone(sub_graph_nodes)
|
||||
self.assertListEqual(sorted(sub_graph_nodes), sorted([1, 2, 3, 4]))
|
||||
|
||||
def test_sub_graph_between_nodes_control_flow_included_forward(self):
|
||||
"""
|
||||
Check that the function works correctly for case when control flow edges should not be traversed (edge 3 -> 5).
|
||||
1 -> 2 -> 3 -> 4
|
||||
\
|
||||
-> 5 -> 6
|
||||
"""
|
||||
graph = Graph()
|
||||
graph.add_nodes_from(list(range(1, 7)))
|
||||
graph.add_edges_from([(1, 2), (2, 3), (3, 4), (3, 5, {'control_flow_edge': True}), (5, 6)])
|
||||
sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=True)
|
||||
self.assertIsNotNone(sub_graph_nodes)
|
||||
self.assertListEqual(sorted(sub_graph_nodes), sorted([1, 2, 3, 4, 5, 6]))
|
||||
|
||||
def test_sub_graph_between_nodes_control_flow_not_included_forward(self):
|
||||
"""
|
||||
Check that the function works correctly for case when control flow edges should not be traversed (edge 3 -> 5).
|
||||
1 -> 2 -> 3 -> 4
|
||||
\
|
||||
-> 5 -> 6
|
||||
"""
|
||||
graph = Graph()
|
||||
graph.add_nodes_from(list(range(1, 7)))
|
||||
graph.add_edges_from([(1, 2), (2, 3), (3, 4), (3, 5, {'control_flow_edge': True}), (5, 6)])
|
||||
sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4], include_control_flow=False)
|
||||
self.assertIsNotNone(sub_graph_nodes)
|
||||
self.assertListEqual(sorted(sub_graph_nodes), sorted([1, 2, 3, 4]))
|
||||
|
@ -64,12 +64,15 @@ mapping_rules = [
|
||||
('crop_height', '.*/rfcn_box_predictor/crop_height'),
|
||||
('crop_width', '.*/rfcn_box_predictor/crop_width'),
|
||||
'initial_crop_size',
|
||||
('use_matmul_crop_and_resize', 'use_matmul_crop_and_resize', False),
|
||||
('add_background_class', 'add_background_class', True),
|
||||
# Detection Output layer attributes
|
||||
('postprocessing_score_converter', '.*/score_converter'),
|
||||
('postprocessing_score_threshold', '.*/batch_non_max_suppression/score_threshold'),
|
||||
('postprocessing_iou_threshold', '.*/batch_non_max_suppression/iou_threshold'),
|
||||
('postprocessing_max_detections_per_class', '.*/batch_non_max_suppression/max_detections_per_class'),
|
||||
('postprocessing_max_total_detections', '.*/batch_non_max_suppression/max_total_detections'),
|
||||
('share_box_across_classes', 'second_stage_box_predictor/.*/share_box_across_classes$', False),
|
||||
# Variances for predicted bounding box deltas (tx, ty, tw, th)
|
||||
('frcnn_variance_x', 'box_coder/faster_rcnn_box_coder/x_scale', 10.0),
|
||||
('frcnn_variance_y', 'box_coder/faster_rcnn_box_coder/y_scale', 10.0),
|
||||
|
@ -148,6 +148,9 @@ class TestingSimpleProtoParser(unittest.TestCase):
|
||||
'ssd_anchor_generator_min_scale': 0.2,
|
||||
'ssd_anchor_generator_max_scale': 0.95,
|
||||
'ssd_anchor_generator_interpolated_scale_aspect_ratio': 1.0,
|
||||
'use_matmul_crop_and_resize': False,
|
||||
'add_background_class': True,
|
||||
'share_box_across_classes': False,
|
||||
}
|
||||
os.unlink(file_name)
|
||||
self.assertDictEqual(pipeline_config._model_params, expected_result)
|
||||
|
@ -23,12 +23,9 @@ from mo.utils.version import get_version
|
||||
|
||||
|
||||
class TestingVersion(unittest.TestCase):
|
||||
def test_unknown_version(self):
|
||||
self.assertNotEqual(get_version(), "unknown version")
|
||||
|
||||
@patch('os.path.isfile')
|
||||
@mock.patch('builtins.open', new_callable=mock_open, create=True, read_data='2021.1.0-1028-55e4d5673a8')
|
||||
def test_get_version(self, mock_open, mock_isfile):
|
||||
mock_isfile.return_value = True
|
||||
mock_open.return_value.__enter__ = mock_open
|
||||
self.assertEqual(get_version(), '2021.1.0-1028-55e4d5673a8')
|
||||
self.assertEqual(get_version(), '2021.1.0-1028-55e4d5673a8')
|
||||
|
Loading…
Reference in New Issue
Block a user