[MO] Updating MO to detect TF 2.X OD API models (#6983)
* updated FasterRCNN and SSD analysis patterns * updated tf od api conditions * updated ssd patterns * added more ssd topologies * move preprocessor to tf od api condition * update TF OD API conditions * refactoring * specify data type
This commit is contained in:
parent
754ee2eb1a
commit
9d53b3536d
@ -16,46 +16,57 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
|
||||
"""
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
||||
|
||||
model_scopes = [('MaskRCNN', ['Preprocessor',
|
||||
'FirstStageFeatureExtractor',
|
||||
'SecondStageFeatureExtractor',
|
||||
'SecondStageBoxPredictor',
|
||||
'SecondStageBoxPredictor_1',
|
||||
'SecondStageFeatureExtractor_1',
|
||||
]),
|
||||
('RFCN', ['Preprocessor',
|
||||
'FirstStageFeatureExtractor',
|
||||
'SecondStageFeatureExtractor',
|
||||
'SecondStageBoxPredictor',
|
||||
'SecondStageBoxPredictor/map',
|
||||
'SecondStageBoxPredictor/map_1',
|
||||
'SecondStagePostprocessor',
|
||||
]),
|
||||
('FasterRCNN', ['Preprocessor',
|
||||
'FirstStageFeatureExtractor',
|
||||
'SecondStageFeatureExtractor',
|
||||
'SecondStageBoxPredictor',
|
||||
'SecondStagePostprocessor',
|
||||
]),
|
||||
('SSD', ['Preprocessor',
|
||||
'FeatureExtractor',
|
||||
'Postprocessor',
|
||||
]),
|
||||
]
|
||||
|
||||
file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json',
|
||||
'RFCN': 'rfcn_support.*\\.json',
|
||||
'FasterRCNN': 'faster_rcnn_support.*\\.json',
|
||||
'SSD': 'ssd.*_support.*\\.json',
|
||||
}
|
||||
model_scopes = {'MaskRCNN': (['FirstStageFeatureExtractor',
|
||||
'SecondStageFeatureExtractor',
|
||||
'SecondStageBoxPredictor',
|
||||
'SecondStageBoxPredictor_1',
|
||||
'SecondStageFeatureExtractor_1',
|
||||
],),
|
||||
'RFCN': (['FirstStageFeatureExtractor',
|
||||
'SecondStageFeatureExtractor',
|
||||
'SecondStageBoxPredictor',
|
||||
'SecondStageBoxPredictor/map',
|
||||
'SecondStageBoxPredictor/map_1',
|
||||
'SecondStagePostprocessor',
|
||||
],),
|
||||
'FasterRCNN': (['FirstStageFeatureExtractor',
|
||||
'SecondStageFeatureExtractor',
|
||||
'SecondStageBoxPredictor',
|
||||
'SecondStagePostprocessor',
|
||||
],
|
||||
['FirstStageRPNFeatures',
|
||||
'FirstStageBoxPredictor',
|
||||
'SecondStagePostprocessor',
|
||||
'mask_rcnn_keras_box_predictor',
|
||||
],),
|
||||
'SSD': ([('FeatureExtractor', 'ssd_mobile_net_v2keras_feature_extractor',
|
||||
'ssd_mobile_net_v1fpn_keras_feature_extractor',
|
||||
'ssd_mobile_net_v2fpn_keras_feature_extractor', 'ResNet50V1_FPN', 'ResNet101V1_FPN',
|
||||
'ResNet152V1_FPN'
|
||||
),
|
||||
'Postprocessor']
|
||||
),
|
||||
}
|
||||
|
||||
def analyze(self, graph: Graph):
|
||||
if any([name not in graph.nodes() for name in ['image_tensor', 'detection_classes', 'detection_boxes',
|
||||
'detection_scores']]):
|
||||
tf_1_names = ['image_tensor', 'detection_classes', 'detection_boxes', 'detection_scores',
|
||||
('Preprocessor', 'map')]
|
||||
tf_1_cond = all([graph_contains_scope(graph, scope) for scope in tf_1_names])
|
||||
|
||||
tf_2_names = ['input_tensor', 'output_control_node', 'Identity_', ('Preprocessor', 'map')]
|
||||
tf_2_cond = all([graph_contains_scope(graph, scope) for scope in tf_2_names])
|
||||
|
||||
if not tf_1_cond and not tf_2_cond:
|
||||
log.debug('The model does not contain nodes that must exist in the TF OD API models')
|
||||
return None, None
|
||||
|
||||
for flavor, scopes in __class__.model_scopes:
|
||||
for flavor, scopes_tuple in self.model_scopes.items():
|
||||
for scopes in scopes_tuple:
|
||||
if all([graph_contains_scope(graph, scope) for scope in scopes]):
|
||||
result = dict()
|
||||
result['flavor'] = flavor
|
||||
@ -67,7 +78,7 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
|
||||
}
|
||||
message = "Your model looks like TensorFlow Object Detection API Model.\n" \
|
||||
"Check if all parameters are specified:\n" \
|
||||
"\t--tensorflow_use_custom_operations_config\n" \
|
||||
"\t--transformations_config\n" \
|
||||
"\t--tensorflow_object_detection_api_pipeline_config\n" \
|
||||
"\t--input_shape (optional)\n" \
|
||||
"\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \
|
||||
|
@ -112,13 +112,14 @@ class AnalysisCollectorAnchor(AnalyzeAction):
|
||||
pass
|
||||
|
||||
|
||||
def graph_contains_scope(graph: Graph, scope: str):
|
||||
def graph_contains_scope(graph: Graph, scope: [str, tuple]):
|
||||
"""
|
||||
Checks whether the graph contains node(s) which name starts with "scope" string.
|
||||
Checks whether the graph contains node(s) which name includes "scope" string.
|
||||
:param graph: graph to check
|
||||
:param scope: string defining the scope
|
||||
:param scope: string or tuple with strings defining the scope
|
||||
:return: the result of the check (True/False)
|
||||
"""
|
||||
if scope[-1] != '/':
|
||||
scope += '/'
|
||||
return any([node.soft_get('name').startswith(scope) for node in graph.get_op_nodes()])
|
||||
if type(scope) is str:
|
||||
return any([node.soft_get('name').find(scope) != -1 for node in graph.get_op_nodes()])
|
||||
else:
|
||||
return any([graph_contains_scope(graph, s) for s in scope])
|
||||
|
Loading…
Reference in New Issue
Block a user