From 9d53b3536d82c974382b25ba9489276c71ba8102 Mon Sep 17 00:00:00 2001 From: Yegor Kruglov Date: Fri, 10 Sep 2021 17:44:42 +0300 Subject: [PATCH] [MO] Updating MO to detect TF 2.X OD API models (#6983) * updated FasterRCNN and SSD analysis patterns * updated tf od api conditions * updated ssd patterns * added more ssd topologies * move preprocessor to tf od api condition * update TF OD API conditions * refactoring * specify data type --- .../extensions/analysis/tf_od_api.py | 107 ++++++++++-------- model-optimizer/mo/utils/model_analysis.py | 13 ++- 2 files changed, 66 insertions(+), 54 deletions(-) diff --git a/model-optimizer/extensions/analysis/tf_od_api.py b/model-optimizer/extensions/analysis/tf_od_api.py index 8fe6d78964a..214f8963694 100644 --- a/model-optimizer/extensions/analysis/tf_od_api.py +++ b/model-optimizer/extensions/analysis/tf_od_api.py @@ -16,62 +16,73 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction): """ graph_condition = [lambda graph: graph.graph['fw'] == 'tf'] - model_scopes = [('MaskRCNN', ['Preprocessor', - 'FirstStageFeatureExtractor', - 'SecondStageFeatureExtractor', - 'SecondStageBoxPredictor', - 'SecondStageBoxPredictor_1', - 'SecondStageFeatureExtractor_1', - ]), - ('RFCN', ['Preprocessor', - 'FirstStageFeatureExtractor', - 'SecondStageFeatureExtractor', - 'SecondStageBoxPredictor', - 'SecondStageBoxPredictor/map', - 'SecondStageBoxPredictor/map_1', - 'SecondStagePostprocessor', - ]), - ('FasterRCNN', ['Preprocessor', - 'FirstStageFeatureExtractor', - 'SecondStageFeatureExtractor', - 'SecondStageBoxPredictor', - 'SecondStagePostprocessor', - ]), - ('SSD', ['Preprocessor', - 'FeatureExtractor', - 'Postprocessor', - ]), - ] - file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json', 'RFCN': 'rfcn_support.*\\.json', 'FasterRCNN': 'faster_rcnn_support.*\\.json', 'SSD': 'ssd.*_support.*\\.json', } + model_scopes = {'MaskRCNN': (['FirstStageFeatureExtractor', + 'SecondStageFeatureExtractor', + 'SecondStageBoxPredictor', + 'SecondStageBoxPredictor_1', + 'SecondStageFeatureExtractor_1', + ],), + 'RFCN': (['FirstStageFeatureExtractor', + 'SecondStageFeatureExtractor', + 'SecondStageBoxPredictor', + 'SecondStageBoxPredictor/map', + 'SecondStageBoxPredictor/map_1', + 'SecondStagePostprocessor', + ],), + 'FasterRCNN': (['FirstStageFeatureExtractor', + 'SecondStageFeatureExtractor', + 'SecondStageBoxPredictor', + 'SecondStagePostprocessor', + ], + ['FirstStageRPNFeatures', + 'FirstStageBoxPredictor', + 'SecondStagePostprocessor', + 'mask_rcnn_keras_box_predictor', + ],), + 'SSD': ([('FeatureExtractor', 'ssd_mobile_net_v2keras_feature_extractor', + 'ssd_mobile_net_v1fpn_keras_feature_extractor', + 'ssd_mobile_net_v2fpn_keras_feature_extractor', 'ResNet50V1_FPN', 'ResNet101V1_FPN', + 'ResNet152V1_FPN' + ), + 'Postprocessor'] + ), + } def analyze(self, graph: Graph): - if any([name not in graph.nodes() for name in ['image_tensor', 'detection_classes', 'detection_boxes', - 'detection_scores']]): + tf_1_names = ['image_tensor', 'detection_classes', 'detection_boxes', 'detection_scores', + ('Preprocessor', 'map')] + tf_1_cond = all([graph_contains_scope(graph, scope) for scope in tf_1_names]) + + tf_2_names = ['input_tensor', 'output_control_node', 'Identity_', ('Preprocessor', 'map')] + tf_2_cond = all([graph_contains_scope(graph, scope) for scope in tf_2_names]) + + if not tf_1_cond and not tf_2_cond: log.debug('The model does not contain nodes that must exist in the TF OD API models') return None, None - for flavor, scopes in __class__.model_scopes: - if all([graph_contains_scope(graph, scope) for scope in scopes]): - result = dict() - result['flavor'] = flavor - result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config': - files_by_pattern(get_mo_root_dir() + '/extensions/front/tf', - __class__.file_patterns[flavor], - add_prefix=True), - 'tensorflow_object_detection_api_pipeline_config': None, - } - message = "Your model looks like TensorFlow Object Detection API Model.\n" \ - "Check if all parameters are specified:\n" \ - "\t--tensorflow_use_custom_operations_config\n" \ - "\t--tensorflow_object_detection_api_pipeline_config\n" \ - "\t--input_shape (optional)\n" \ - "\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \ - "Detailed information about conversion of this model can be found at\n" \ - "https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_Object_Detection_API_Models.html" - return {'model_type': {'TF_OD_API': result}}, message + for flavor, scopes_tuple in self.model_scopes.items(): + for scopes in scopes_tuple: + if all([graph_contains_scope(graph, scope) for scope in scopes]): + result = dict() + result['flavor'] = flavor + result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config': + files_by_pattern(get_mo_root_dir() + '/extensions/front/tf', + __class__.file_patterns[flavor], + add_prefix=True), + 'tensorflow_object_detection_api_pipeline_config': None, + } + message = "Your model looks like TensorFlow Object Detection API Model.\n" \ + "Check if all parameters are specified:\n" \ + "\t--transformations_config\n" \ + "\t--tensorflow_object_detection_api_pipeline_config\n" \ + "\t--input_shape (optional)\n" \ + "\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \ + "Detailed information about conversion of this model can be found at\n" \ + "https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_Object_Detection_API_Models.html" + return {'model_type': {'TF_OD_API': result}}, message return None, None diff --git a/model-optimizer/mo/utils/model_analysis.py b/model-optimizer/mo/utils/model_analysis.py index 2049f8b41ec..4fe47637b7b 100644 --- a/model-optimizer/mo/utils/model_analysis.py +++ b/model-optimizer/mo/utils/model_analysis.py @@ -112,13 +112,14 @@ class AnalysisCollectorAnchor(AnalyzeAction): pass -def graph_contains_scope(graph: Graph, scope: str): +def graph_contains_scope(graph: Graph, scope: [str, tuple]): """ - Checks whether the graph contains node(s) which name starts with "scope" string. + Checks whether the graph contains node(s) which name includes "scope" string. :param graph: graph to check - :param scope: string defining the scope + :param scope: string or tuple with strings defining the scope :return: the result of the check (True/False) """ - if scope[-1] != '/': - scope += '/' - return any([node.soft_get('name').startswith(scope) for node in graph.get_op_nodes()]) + if type(scope) is str: + return any([node.soft_get('name').find(scope) != -1 for node in graph.get_op_nodes()]) + else: + return any([graph_contains_scope(graph, s) for s in scope])