Preserve inputs and outputs order for TensorFlow 2 and MXNet models (#9683)

* Preserve outputs order TF.

* Preserve of input/output indices for MxNet.

* Small fix.

* Added check.

* Small fix.

* Corrected Keras model importing.

* Fixed Keras model loading.

* Small correction.

* Corrected model loading.

* Small fix.

* Comment corrected.

* Removed unnecessary import.
This commit is contained in:
Anastasia Popova
2022-01-25 10:13:05 +03:00
committed by GitHub
parent 6f60a7d8f0
commit d27bbb4bdd
6 changed files with 46 additions and 21 deletions

View File

@@ -93,6 +93,8 @@ def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
else:
input_names = input_names.split(',')
graph.inputs_order = input_names
rnn_states = init_rnn_states(model_nodes)
names_rnn_states = list(rnn_states.keys())
@@ -125,6 +127,8 @@ def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
output_ids = [index_node_keys[node_id] for node_id in set(range(len(model_nodes))) - used_indices_set]
graph.outputs_order = output_ids
# Tensor names information corresponding to a node is stored on outgoing edges.
# As output nodes do not have outgoing edges, fake outputs are required. In the following code
# for each output Identity node is added, and tensor name for the output is kept

View File

@@ -190,7 +190,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
try:
if graph_file_name and not meta_graph_file and not checkpoint:
# frozen graph
return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values, 'tf'
return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values, 'tf', None
if graph_file_name and not meta_graph_file and checkpoint:
# inference graph and checkpoint
graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
@@ -201,7 +201,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint,
output_node_names=outputs)
# we are sure that checkpoint is existing file or directory due to cli_parser configuration
return graph_def, variables_values, 'tf'
return graph_def, variables_values, 'tf', None
if not graph_file_name and meta_graph_file:
meta_graph_file = deducing_metagraph_path(meta_graph_file)
input_meta_graph_def = read_file_to_graph_def(tf_v1.MetaGraphDef(), meta_graph_file, is_binary)
@@ -212,16 +212,19 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list)
graph_def = tf_v1.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def,
outputs)
return graph_def, variables_values, 'tf'
return graph_def, variables_values, 'tf', None
if model_dir:
# saved model directory
try:
env_setup = get_environment_setup("tf")
# enable eager execution temporarily while TensorFlow 2 model is being loaded
tf_v1.enable_eager_execution()
# code to extract GraphDef for TF 2.0 SavedModel format
# tf.saved_model.load function throws TypeError for TF 1.x SavedModel format in case TF 1.x installed
imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
# Code to extract Keras model.
# tf.keras.models.load_model function throws TypeError,KeyError or IndexError
# for TF 1.x SavedModel format in case TF 1.x installed
imported = tf.keras.models.load_model(model_dir, compile=False)
# to get a signature by key throws KeyError for TF 1.x SavedModel format in case TF 2.x installed
concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
# the aggressive inlining parameter needs to freeze a table of embeddings for Keras Embedding operation
@@ -236,8 +239,22 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
# disable eager execution since next steps are executed with a graph in non-eager mode
tf_v1.disable_eager_execution()
return graph_def, variables_values, 'tf2'
except (TypeError, KeyError):
input_names = []
if hasattr(imported, 'inputs'):
# Extract tensor names order from Keras model
input_names = [tensor.name for tensor in imported.inputs]
# After model freezing output tensor names are changing and recieve "Func/PartitionedCall" prefix,
# so output_names from saved_model cannot be used. Here tensor names from frozen graph are used,
# as TF adds indexed Identity nodes during freezing to each output, so this indexing is used for
# order alignment.
output_names = [tensor.name for tensor in frozen_func.outputs]
inputs_outputs_order = (input_names, output_names)
return graph_def, variables_values, 'tf2', inputs_outputs_order
except:
# disable eager execution since TensorFlow 1 model is handled
tf_v1.disable_eager_execution()
# code to extract GraphDef for TF 1.0 SavedModel format
@@ -246,9 +263,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
meta_graph_def = tf_v1.saved_model.loader.load(sess, tags, model_dir)
outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list)
graph_def = tf_v1.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
return graph_def, variables_values, 'tf'
except Exception as e:
raise FrameworkError('SavedModel format load failure: {}', e) from e
return graph_def, variables_values, 'tf', None
except Exception as e:
raise FrameworkError('Cannot load input model: {}', e) from e
raise Error("Unknown configuration of input model parameters")

View File

@@ -42,13 +42,19 @@ class TFLoader(Loader):
log.info('Loading library "{}" with custom operations'.format(library))
tf_v1.load_op_library(library)
graph_def, variables_values, framework = load_tf_graph_def(graph_file_name=argv.input_model,
is_binary=not argv.input_model_is_text,
checkpoint=argv.input_checkpoint,
user_output_node_names_list=argv.output,
model_dir=argv.saved_model_dir,
meta_graph_file=argv.input_meta_graph,
saved_model_tags=argv.saved_model_tags)
graph_def, variables_values, framework, inputs_outputs_order = load_tf_graph_def(
graph_file_name=argv.input_model,
is_binary=not argv.input_model_is_text,
checkpoint=argv.input_checkpoint,
user_output_node_names_list=argv.output,
model_dir=argv.saved_model_dir,
meta_graph_file=argv.input_meta_graph,
saved_model_tags=argv.saved_model_tags)
if inputs_outputs_order is not None and isinstance(inputs_outputs_order, tuple):
graph.inputs_order = inputs_outputs_order[0]
graph.outputs_order = inputs_outputs_order[1]
send_framework_info(framework)
try:

View File

@@ -31,7 +31,7 @@ def convert(filename: str, is_text: bool):
new_ext = ".pbtxt" if is_text else ".pb"
head, tail = os.path.split(os.path.abspath(filename))
print("Convert: {} \n to: {}".format(filename, os.path.join(head, tail + new_ext)))
graph_def, _ = load_tf_graph_def(graph_file_name=filename, is_binary=is_text)
graph_def, _, _, _ = load_tf_graph_def(graph_file_name=filename, is_binary=is_text)
tf_v1.import_graph_def(graph_def, name='')
tf_v1.train.write_graph(graph_def, head, tail + new_ext, as_text=is_text)

View File

@@ -71,7 +71,7 @@ def main():
print("[ ERROR ] Both keys were provided --input_model and --input_dir. Please, provide only one of them")
sys.exit(1)
tags = argv.saved_model_tags.split(",")
graph_def, _, _ = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text,
graph_def, _, _, _ = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text,
checkpoint=argv.input_checkpoint,
model_dir=argv.saved_model_dir, saved_model_tags=tags)
summary = summarize_graph(graph_def)

View File

@@ -15,7 +15,7 @@ pbtxt = 'node{name:"Placeholder"op:"Placeholder"attr{key:"dtype"value{type:DT_FL
class TestingSummarizeGraph(unittest.TestCase):
def test_summarize_graph(self):
with patch('openvino.tools.mo.front.tf.loader.open', mock_open(read_data=pbtxt)) as m:
graph_def, _, _ = load_tf_graph_def('path', False)
graph_def, _, _, _ = load_tf_graph_def('path', False)
summary = summarize_graph(graph_def)
self.assertEqual(len(summary['outputs']), 1)
self.assertEqual(summary['outputs'][0], 'Output/Identity')