[MO] Updating MO to detect TF 2.X OD API models (#6983)
* updated FasterRCNN and SSD analysis patterns * updated tf od api conditions * updated ssd patterns * added more ssd topologies * move preprocessor to tf od api condition * update TF OD API conditions * refactoring * specify data type
This commit is contained in:
parent
754ee2eb1a
commit
9d53b3536d
@ -16,62 +16,73 @@ class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
|
|||||||
"""
|
"""
|
||||||
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
|
||||||
|
|
||||||
model_scopes = [('MaskRCNN', ['Preprocessor',
|
|
||||||
'FirstStageFeatureExtractor',
|
|
||||||
'SecondStageFeatureExtractor',
|
|
||||||
'SecondStageBoxPredictor',
|
|
||||||
'SecondStageBoxPredictor_1',
|
|
||||||
'SecondStageFeatureExtractor_1',
|
|
||||||
]),
|
|
||||||
('RFCN', ['Preprocessor',
|
|
||||||
'FirstStageFeatureExtractor',
|
|
||||||
'SecondStageFeatureExtractor',
|
|
||||||
'SecondStageBoxPredictor',
|
|
||||||
'SecondStageBoxPredictor/map',
|
|
||||||
'SecondStageBoxPredictor/map_1',
|
|
||||||
'SecondStagePostprocessor',
|
|
||||||
]),
|
|
||||||
('FasterRCNN', ['Preprocessor',
|
|
||||||
'FirstStageFeatureExtractor',
|
|
||||||
'SecondStageFeatureExtractor',
|
|
||||||
'SecondStageBoxPredictor',
|
|
||||||
'SecondStagePostprocessor',
|
|
||||||
]),
|
|
||||||
('SSD', ['Preprocessor',
|
|
||||||
'FeatureExtractor',
|
|
||||||
'Postprocessor',
|
|
||||||
]),
|
|
||||||
]
|
|
||||||
|
|
||||||
file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json',
|
file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json',
|
||||||
'RFCN': 'rfcn_support.*\\.json',
|
'RFCN': 'rfcn_support.*\\.json',
|
||||||
'FasterRCNN': 'faster_rcnn_support.*\\.json',
|
'FasterRCNN': 'faster_rcnn_support.*\\.json',
|
||||||
'SSD': 'ssd.*_support.*\\.json',
|
'SSD': 'ssd.*_support.*\\.json',
|
||||||
}
|
}
|
||||||
|
model_scopes = {'MaskRCNN': (['FirstStageFeatureExtractor',
|
||||||
|
'SecondStageFeatureExtractor',
|
||||||
|
'SecondStageBoxPredictor',
|
||||||
|
'SecondStageBoxPredictor_1',
|
||||||
|
'SecondStageFeatureExtractor_1',
|
||||||
|
],),
|
||||||
|
'RFCN': (['FirstStageFeatureExtractor',
|
||||||
|
'SecondStageFeatureExtractor',
|
||||||
|
'SecondStageBoxPredictor',
|
||||||
|
'SecondStageBoxPredictor/map',
|
||||||
|
'SecondStageBoxPredictor/map_1',
|
||||||
|
'SecondStagePostprocessor',
|
||||||
|
],),
|
||||||
|
'FasterRCNN': (['FirstStageFeatureExtractor',
|
||||||
|
'SecondStageFeatureExtractor',
|
||||||
|
'SecondStageBoxPredictor',
|
||||||
|
'SecondStagePostprocessor',
|
||||||
|
],
|
||||||
|
['FirstStageRPNFeatures',
|
||||||
|
'FirstStageBoxPredictor',
|
||||||
|
'SecondStagePostprocessor',
|
||||||
|
'mask_rcnn_keras_box_predictor',
|
||||||
|
],),
|
||||||
|
'SSD': ([('FeatureExtractor', 'ssd_mobile_net_v2keras_feature_extractor',
|
||||||
|
'ssd_mobile_net_v1fpn_keras_feature_extractor',
|
||||||
|
'ssd_mobile_net_v2fpn_keras_feature_extractor', 'ResNet50V1_FPN', 'ResNet101V1_FPN',
|
||||||
|
'ResNet152V1_FPN'
|
||||||
|
),
|
||||||
|
'Postprocessor']
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def analyze(self, graph: Graph):
|
def analyze(self, graph: Graph):
|
||||||
if any([name not in graph.nodes() for name in ['image_tensor', 'detection_classes', 'detection_boxes',
|
tf_1_names = ['image_tensor', 'detection_classes', 'detection_boxes', 'detection_scores',
|
||||||
'detection_scores']]):
|
('Preprocessor', 'map')]
|
||||||
|
tf_1_cond = all([graph_contains_scope(graph, scope) for scope in tf_1_names])
|
||||||
|
|
||||||
|
tf_2_names = ['input_tensor', 'output_control_node', 'Identity_', ('Preprocessor', 'map')]
|
||||||
|
tf_2_cond = all([graph_contains_scope(graph, scope) for scope in tf_2_names])
|
||||||
|
|
||||||
|
if not tf_1_cond and not tf_2_cond:
|
||||||
log.debug('The model does not contain nodes that must exist in the TF OD API models')
|
log.debug('The model does not contain nodes that must exist in the TF OD API models')
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
for flavor, scopes in __class__.model_scopes:
|
for flavor, scopes_tuple in self.model_scopes.items():
|
||||||
if all([graph_contains_scope(graph, scope) for scope in scopes]):
|
for scopes in scopes_tuple:
|
||||||
result = dict()
|
if all([graph_contains_scope(graph, scope) for scope in scopes]):
|
||||||
result['flavor'] = flavor
|
result = dict()
|
||||||
result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config':
|
result['flavor'] = flavor
|
||||||
files_by_pattern(get_mo_root_dir() + '/extensions/front/tf',
|
result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config':
|
||||||
__class__.file_patterns[flavor],
|
files_by_pattern(get_mo_root_dir() + '/extensions/front/tf',
|
||||||
add_prefix=True),
|
__class__.file_patterns[flavor],
|
||||||
'tensorflow_object_detection_api_pipeline_config': None,
|
add_prefix=True),
|
||||||
}
|
'tensorflow_object_detection_api_pipeline_config': None,
|
||||||
message = "Your model looks like TensorFlow Object Detection API Model.\n" \
|
}
|
||||||
"Check if all parameters are specified:\n" \
|
message = "Your model looks like TensorFlow Object Detection API Model.\n" \
|
||||||
"\t--tensorflow_use_custom_operations_config\n" \
|
"Check if all parameters are specified:\n" \
|
||||||
"\t--tensorflow_object_detection_api_pipeline_config\n" \
|
"\t--transformations_config\n" \
|
||||||
"\t--input_shape (optional)\n" \
|
"\t--tensorflow_object_detection_api_pipeline_config\n" \
|
||||||
"\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \
|
"\t--input_shape (optional)\n" \
|
||||||
"Detailed information about conversion of this model can be found at\n" \
|
"\t--reverse_input_channels (if you convert a model to use with the Inference Engine sample applications)\n" \
|
||||||
"https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_Object_Detection_API_Models.html"
|
"Detailed information about conversion of this model can be found at\n" \
|
||||||
return {'model_type': {'TF_OD_API': result}}, message
|
"https://docs.openvinotoolkit.org/latest/_docs_MO_DG_prepare_model_convert_model_tf_specific_Convert_Object_Detection_API_Models.html"
|
||||||
|
return {'model_type': {'TF_OD_API': result}}, message
|
||||||
return None, None
|
return None, None
|
||||||
|
@ -112,13 +112,14 @@ class AnalysisCollectorAnchor(AnalyzeAction):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def graph_contains_scope(graph: Graph, scope: str):
|
def graph_contains_scope(graph: Graph, scope: [str, tuple]):
|
||||||
"""
|
"""
|
||||||
Checks whether the graph contains node(s) which name starts with "scope" string.
|
Checks whether the graph contains node(s) which name includes "scope" string.
|
||||||
:param graph: graph to check
|
:param graph: graph to check
|
||||||
:param scope: string defining the scope
|
:param scope: string or tuple with strings defining the scope
|
||||||
:return: the result of the check (True/False)
|
:return: the result of the check (True/False)
|
||||||
"""
|
"""
|
||||||
if scope[-1] != '/':
|
if type(scope) is str:
|
||||||
scope += '/'
|
return any([node.soft_get('name').find(scope) != -1 for node in graph.get_op_nodes()])
|
||||||
return any([node.soft_get('name').startswith(scope) for node in graph.get_op_nodes()])
|
else:
|
||||||
|
return any([graph_contains_scope(graph, s) for s in scope])
|
||||||
|
Loading…
Reference in New Issue
Block a user