Files
openvino/model-optimizer/extensions/load/tf/loader.py
Eugeny Volosenkov 38022c4cd6 Mo implementation for If with tf extractor (#6662)
* Add tf2.x impl for If

* Fix ir_engine

* Fix opset

* Fix BOM file

* Added new test

* Fix comments

* Add subgraph_utils

* Fix comments

* Fix transform

* code refactoring

* Fix description

* rewrite support for empty tensor in if

* added onnx extractor

* delete onnx_if

* fix bug with fake_outputs

* Fix test

* Fix control_flow and fix commentaries

* create method results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph,
2021-08-19 10:13:21 +03:00

146 lines
6.8 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
try:
import tensorflow.compat.v1 as tf_v1
# disable eager execution of TensorFlow 2 environment immediately
tf_v1.disable_eager_execution()
except ImportError:
import tensorflow as tf_v1
try:
import tensorflow.contrib # pylint: disable=no-name-in-module,import-error
except:
pass # we try to import contrib for loading models that use contrib operations
import logging as log
from extensions.load.loader import Loader
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.extractor import restore_edges, extract_node_attrs, remove_control_dependency_inputs, add_outputs_identity
from mo.front.tf.extractor import get_tf_edges, create_tf_edge, tf_op_extractor, tf_op_extractors
from mo.front.tf.loader import load_tf_graph_def, protobuf2nx
from mo.graph.graph import Graph, Node
from mo.utils import tensorboard_util
from mo.utils.error import Error
from mo.utils.telemetry_utils import send_op_names_info, send_shapes_info, send_framework_info
from mo.utils.utils import refer_to_faq_msg
class TFLoader(Loader):
    """
    Loader stage that reads a TensorFlow model using the command line parameters stored in the graph and
    fills the graph with the nodes, edges and extracted node attributes of that model.
    """
    enabled = True
    run_not_recursively = True

    def load(self, graph: Graph):
        """
        Reads the TensorFlow model and populates the graph in place.

        :param graph: graph to fill; command line parameters are read from graph.graph['cmd_params']
        :return: None
        :raises Error: if protobuf2nx fails to convert the loaded GraphDef
        """
        argv = graph.graph['cmd_params']
        if argv.tensorflow_custom_layer_libraries:
            # libraries with custom operations are loaded before the model itself is read
            libraries = argv.tensorflow_custom_layer_libraries.split(',')
            for library in libraries:
                log.info('Loading library "{}" with custom operations'.format(library))
                tf_v1.load_op_library(library)

        graph_def, variables_values, framework = load_tf_graph_def(graph_file_name=argv.input_model,
                                                                   is_binary=not argv.input_model_is_text,
                                                                   checkpoint=argv.input_checkpoint,
                                                                   user_output_node_names_list=argv.output,
                                                                   model_dir=argv.saved_model_dir,
                                                                   meta_graph_file=argv.input_meta_graph,
                                                                   saved_model_tags=argv.saved_model_tags)
        send_framework_info(framework)

        try:
            tf_v1.import_graph_def(graph_def, name='')
        except Exception:
            # Best-effort step: a failure is reported and the conversion goes on. Only Exception is caught so that
            # KeyboardInterrupt and SystemExit are not swallowed here.
            log.warning("TensorFlow post-processing of loaded model was unsuccessful. "
                        "This is an optional step that Model Optimizer performs for any input model but it is not usually "
                        "required for all models. "
                        "It likely means that the original model is ill-formed. "
                        "Model Optimizer will continue converting this model.")

        log.debug("Number of nodes in graph_def: {}".format(len(graph_def.node)))  # pylint: disable=no-member

        if argv.tensorboard_logdir:
            tensorboard_util.dump_for_tensorboard(graph_def, argv.tensorboard_logdir)

        update_extractors_with_extensions(tf_op_extractors)

        try:
            protobuf2nx(graph, graph_def)
        except Exception as e:
            # the original exception is kept as the cause and its text is passed as the "Details" argument
            raise Error(
                'Cannot pre-process TensorFlow graph after reading from model file "{}". ' \
                'File is corrupt or has unsupported format. Details: {}. ' +
                refer_to_faq_msg(44),
                argv.model_name,
                str(e)
            ) from e

        graph.__setattr__('name', argv.model_name)
        # 'layout' parameter change may cause an issue in EltwiseInputReshape replacer
        # and convert_nhwc_to_nchw(graph)
        graph.graph['layout'] = 'NCHW' if argv.disable_nhwc_to_nchw else 'NHWC'
        graph.graph['fw'] = 'tf'

        graph.graph['variables_values'] = variables_values
        del variables_values

        used_tensors = restore_edges(graph, get_tf_edges)

        # Tensor names information corresponding to a node is stored on outgoing edges.
        # As output nodes do not have outgoing edges, fake outputs are required. In the following code
        # for each output Identity node is added, and tensor name for the output is kept
        # on (output, fake output) edge. After Result nodes adding transformation fake outputs
        # are deleted from graph.
        add_outputs_identity(graph, graph.nodes - used_tensors, lambda g, output, fake_node_name: g.add_edges_from([
            create_tf_edge(output, fake_node_name, 0)]))

        remove_control_dependency_inputs(graph)

        graph.check_empty_graph('protobuf2nx. It may happen due to problems with loaded model')
        extract_node_attrs(graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors)))

        # try to detect layout from the nodes of the graph. If there are no convolution nodes in N(D)HWC layout then we
        # consider that the graph is in NCHW layout and no layout conversion should be performed
        if not argv.disable_nhwc_to_nchw and not argv.silent and not graph_or_sub_graph_has_nhwc_ops(graph):
            log.error('The TensorFlow model does not contain Convolution operations with N(D)HWC layout. Most likely '
                      'the model should be converted using additional "--disable_nhwc_to_nchw" command line parameter '
                      'which disables model layout conversion inside the Model Optimizer.', extra={'is_warning': True})

        send_op_names_info(framework, graph)
        send_shapes_info(framework, graph)
def is_node_layout_nhwc(node: Node):
    """
    Tell whether the node is one of the known convolution operations and its 'layout' attribute is NHWC or NDHWC.

    :param node: Node to check
    :return: True for a convolution node with NHWC/NDHWC layout, False otherwise
    """
    convolution_ops = ("Conv2D", "DepthwiseConv2dNative", "Conv3D", "Conv2DBackpropInput",
                       "Conv3DBackpropInputV2")
    channels_last_layouts = ("NHWC", "NDHWC")

    # the 'layout' attribute is inspected only for the operations listed above
    if node.soft_get('op') not in convolution_ops:
        return False
    if node.soft_get('layout') not in channels_last_layouts:
        return False

    # fall back to the node id when the node has no 'name' attribute
    node_name = node.soft_get('name', node.id)
    log.debug('Detected convolution node with NHWC layout: "{}"'.format(node_name))
    return True
def graph_or_sub_graph_has_nhwc_ops(graph: Graph):
    """
    Checks that a graph, or any sub-graph referenced by the 'sub_graphs' attribute of one of its operation nodes,
    contains nodes with NHWC layout. Sub-graphs are checked recursively, so nested sub-graphs are covered too.

    :param graph: main graph to check
    :return: Boolean result of the check
    """
    for node in graph.get_op_nodes():
        if is_node_layout_nhwc(node):
            return True

        if node.has('sub_graphs'):
            # node['sub_graphs'] holds names of node attributes; each attribute value is passed to the recursive
            # call (presumably a Graph -- verify against the code which sets 'sub_graphs').
            # any() stops at the first sub-graph that reports NHWC, so the remaining ones are not traversed.
            if any(graph_or_sub_graph_has_nhwc_ops(node.soft_get(sub_graph_name))
                   for sub_graph_name in node['sub_graphs']):
                return True
    return False