[MO] Updating MO to detect TF 2.X OD API models (#6983)

* updated FasterRCNN and SSD analysis patterns

* updated tf od api conditions

* updated ssd patterns

* added more ssd topologies

* move preprocessor to tf od api condition

* update TF OD API conditions

* refactoring

* specify data type
This commit is contained in:
Yegor Kruglov 2021-09-10 17:44:42 +03:00 committed by GitHub
parent 754ee2eb1a
commit 9d53b3536d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 54 deletions

View File

@ -16,62 +16,73 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
""" """
graph_condition = [lambda graph: graph.graph['fw'] == 'tf'] graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
model_scopes = [('MaskRCNN', ['Preprocessor',
'FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor_1',
'SecondStageFeatureExtractor_1',
]),
('RFCN', ['Preprocessor',
'FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor/map',
'SecondStageBoxPredictor/map_1',
'SecondStagePostprocessor',
]),
('FasterRCNN', ['Preprocessor',
'FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStagePostprocessor',
]),
('SSD', ['Preprocessor',
'FeatureExtractor',
'Postprocessor',
]),
]
file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json', file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json',
'RFCN': 'rfcn_support.*\\.json', 'RFCN': 'rfcn_support.*\\.json',
'FasterRCNN': 'faster_rcnn_support.*\\.json', 'FasterRCNN': 'faster_rcnn_support.*\\.json',
'SSD': 'ssd.*_support.*\\.json', 'SSD': 'ssd.*_support.*\\.json',
} }
model_scopes = {'MaskRCNN': (['FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor_1',
'SecondStageFeatureExtractor_1',
],),
'RFCN': (['FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor/map',
'SecondStageBoxPredictor/map_1',
'SecondStagePostprocessor',
],),
'FasterRCNN': (['FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStagePostprocessor',
],
['FirstStageRPNFeatures',
'FirstStageBoxPredictor',
'SecondStagePostprocessor',
'mask_rcnn_keras_box_predictor',
],),
'SSD': ([('FeatureExtractor', 'ssd_mobile_net_v2keras_feature_extractor',
'ssd_mobile_net_v1fpn_keras_feature_extractor',
'ssd_mobile_net_v2fpn_keras_feature_extractor', 'ResNet50V1_FPN', 'ResNet101V1_FPN',
'ResNet152V1_FPN'
),
'Postprocessor']
),
}
def analyze(self, graph: Graph): def analyze(self, graph: Graph):
if any([name not in graph.nodes() for name in ['image_tensor', 'detection_classes', 'detection_boxes', tf_1_names = ['image_tensor', 'detection_classes', 'detection_boxes', 'detection_scores',
'detection_scores']]): ('Preprocessor', 'map')]
tf_1_cond = all([graph_contains_scope(graph, scope) for scope in tf_1_names])
tf_2_names = ['input_tensor', 'output_control_node', 'Identity_', ('Preprocessor', 'map')]
tf_2_cond = all([graph_contains_scope(graph, scope) for scope in tf_2_names])
if not tf_1_cond and not tf_2_cond:
log.debug('The model does not contain nodes that must exist in the TF OD API models') log.debug('The model does not contain nodes that must exist in the TF OD API models')
return None, None return None, None
for flavor, scopes in __class__.model_scopes: for flavor, scopes_tuple in self.model_scopes.items():
if all([graph_contains_scope(graph, scope) for scope in scopes]): for scopes in scopes_tuple:
result = dict() if all([graph_contains_scope(graph, scope) for scope in scopes]):
result['flavor'] = flavor result = dict()
result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config': result['flavor'] = flavor
files_by_pattern(get_mo_root_dir() + '/extensions/front/tf', result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config':
__class__.file_patterns[flavor], files_by_pattern(get_mo_root_dir() + '/extensions/front/tf',
add_prefix=True), __class__.file_patterns[flavor],
'tensorflow_object_detection_api_pipeline_config': None, add_prefix=True),
} 'tensorflow_object_detection_api_pipeline_config': None,
message = "Your model looks like TensorFlow Object Detection API Model.\n" \ }
"Check if all parameters are specified:\n" \ message = "Your model looks like TensorFlow Object Detection API Model.\n" \
"\t--tensorflow_use_custom_operations_config\n" \ "Check if all parameters are specified:\n" \
"\t--tensorflow_object_detection_api_pipeline_config\n" \ "\t--transformations_config\n" \
"\t--input_shape (optional)\n" \ "\t--tensorflow_object_detection_api_pipeline_config\n" \
"\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \ "\t--input_shape (optional)\n" \
"Detailed information about conversion of this model can be found at\n" \ "\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \
"https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_Object_Detection_API_Models.html" "Detailed information about conversion of this model can be found at\n" \
return {'model_type': {'TF_OD_API': result}}, message "https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_Object_Detection_API_Models.html"
return {'model_type': {'TF_OD_API': result}}, message
return None, None return None, None

View File

@ -112,13 +112,14 @@ class AnalysisCollectorAnchor(AnalyzeAction):
pass pass
def graph_contains_scope(graph: Graph, scope: str): def graph_contains_scope(graph: Graph, scope: [str, tuple]):
""" """
Checks whether the graph contains node(s) which name starts with "scope" string. Checks whether the graph contains node(s) which name includes "scope" string.
:param graph: graph to check :param graph: graph to check
:param scope: string defining the scope :param scope: string or tuple with strings defining the scope
:return: the result of the check (True/False) :return: the result of the check (True/False)
""" """
if scope[-1] != '/': if type(scope) is str:
scope += '/' return any([node.soft_get('name').find(scope) != -1 for node in graph.get_op_nodes()])
return any([node.soft_get('name').startswith(scope) for node in graph.get_op_nodes()]) else:
return any([graph_contains_scope(graph, s) for s in scope])