[MO] Updating MO to detect TF 2.X OD API models (#6983)

* updated FasterRCNN and SSD analysis patterns

* updated tf od api conditions

* updated ssd patterns

* added more ssd topologies

* move preprocessor to tf od api condition

* update TF OD API conditions

* refactoring

* specify data type
This commit is contained in:
Yegor Kruglov 2021-09-10 17:44:42 +03:00 committed by GitHub
parent 754ee2eb1a
commit 9d53b3536d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 54 deletions

View File

@ -16,46 +16,57 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
"""
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
model_scopes = [('MaskRCNN', ['Preprocessor',
'FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor_1',
'SecondStageFeatureExtractor_1',
]),
('RFCN', ['Preprocessor',
'FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor/map',
'SecondStageBoxPredictor/map_1',
'SecondStagePostprocessor',
]),
('FasterRCNN', ['Preprocessor',
'FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStagePostprocessor',
]),
('SSD', ['Preprocessor',
'FeatureExtractor',
'Postprocessor',
]),
]
file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json',
'RFCN': 'rfcn_support.*\\.json',
'FasterRCNN': 'faster_rcnn_support.*\\.json',
'SSD': 'ssd.*_support.*\\.json',
}
model_scopes = {'MaskRCNN': (['FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor_1',
'SecondStageFeatureExtractor_1',
],),
'RFCN': (['FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStageBoxPredictor/map',
'SecondStageBoxPredictor/map_1',
'SecondStagePostprocessor',
],),
'FasterRCNN': (['FirstStageFeatureExtractor',
'SecondStageFeatureExtractor',
'SecondStageBoxPredictor',
'SecondStagePostprocessor',
],
['FirstStageRPNFeatures',
'FirstStageBoxPredictor',
'SecondStagePostprocessor',
'mask_rcnn_keras_box_predictor',
],),
'SSD': ([('FeatureExtractor', 'ssd_mobile_net_v2keras_feature_extractor',
'ssd_mobile_net_v1fpn_keras_feature_extractor',
'ssd_mobile_net_v2fpn_keras_feature_extractor', 'ResNet50V1_FPN', 'ResNet101V1_FPN',
'ResNet152V1_FPN'
),
'Postprocessor']
),
}
def analyze(self, graph: Graph):
if any([name not in graph.nodes() for name in ['image_tensor', 'detection_classes', 'detection_boxes',
'detection_scores']]):
tf_1_names = ['image_tensor', 'detection_classes', 'detection_boxes', 'detection_scores',
('Preprocessor', 'map')]
tf_1_cond = all([graph_contains_scope(graph, scope) for scope in tf_1_names])
tf_2_names = ['input_tensor', 'output_control_node', 'Identity_', ('Preprocessor', 'map')]
tf_2_cond = all([graph_contains_scope(graph, scope) for scope in tf_2_names])
if not tf_1_cond and not tf_2_cond:
log.debug('The model does not contain nodes that must exist in the TF OD API models')
return None, None
for flavor, scopes in __class__.model_scopes:
for flavor, scopes_tuple in self.model_scopes.items():
for scopes in scopes_tuple:
if all([graph_contains_scope(graph, scope) for scope in scopes]):
result = dict()
result['flavor'] = flavor
@ -67,7 +78,7 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
}
message = "Your model looks like TensorFlow Object Detection API Model.\n" \
"Check if all parameters are specified:\n" \
"\t--tensorflow_use_custom_operations_config\n" \
"\t--transformations_config\n" \
"\t--tensorflow_object_detection_api_pipeline_config\n" \
"\t--input_shape (optional)\n" \
"\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \

View File

@ -112,13 +112,14 @@ class AnalysisCollectorAnchor(AnalyzeAction):
pass
def graph_contains_scope(graph: Graph, scope: str):
def graph_contains_scope(graph: Graph, scope: [str, tuple]):
"""
Checks whether the graph contains node(s) which name starts with "scope" string.
Checks whether the graph contains node(s) which name includes "scope" string.
:param graph: graph to check
:param scope: string defining the scope
:param scope: string or tuple with strings defining the scope
:return: the result of the check (True/False)
"""
if scope[-1] != '/':
scope += '/'
return any([node.soft_get('name').startswith(scope) for node in graph.get_op_nodes()])
if type(scope) is str:
return any([node.soft_get('name').find(scope) != -1 for node in graph.get_op_nodes()])
else:
return any([graph_contains_scope(graph, s) for s in scope])