Preserve inputs and outputs order for TensorFlow 2 and MXNet models (#9683)
* Preserve outputs order TF. * Preserve of input/output indices for MxNet. * Small fix. * Added check. * Small fix. * Corrected Keras model importing. * Fixed Keras model loading. * Small correction. * Corrected model loading. * Small fix. * Comment corrected. * Removed unnecessary import.
This commit is contained in:
@@ -93,6 +93,8 @@ def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
|
||||
else:
|
||||
input_names = input_names.split(',')
|
||||
|
||||
graph.inputs_order = input_names
|
||||
|
||||
rnn_states = init_rnn_states(model_nodes)
|
||||
names_rnn_states = list(rnn_states.keys())
|
||||
|
||||
@@ -125,6 +127,8 @@ def symbol2nx(graph, model_nodes, model_params, input_names: str = ''):
|
||||
|
||||
output_ids = [index_node_keys[node_id] for node_id in set(range(len(model_nodes))) - used_indices_set]
|
||||
|
||||
graph.outputs_order = output_ids
|
||||
|
||||
# Tensor names information corresponding to a node is stored on outgoing edges.
|
||||
# As output nodes do not have outgoing edges, fake outputs are required. In the following code
|
||||
# for each output Identity node is added, and tensor name for the output is kept
|
||||
|
||||
@@ -190,7 +190,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
try:
|
||||
if graph_file_name and not meta_graph_file and not checkpoint:
|
||||
# frozen graph
|
||||
return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values, 'tf'
|
||||
return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values, 'tf', None
|
||||
if graph_file_name and not meta_graph_file and checkpoint:
|
||||
# inference graph and checkpoint
|
||||
graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
|
||||
@@ -201,7 +201,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint,
|
||||
output_node_names=outputs)
|
||||
# we are sure that checkpoint is existing file or directory due to cli_parser configuration
|
||||
return graph_def, variables_values, 'tf'
|
||||
return graph_def, variables_values, 'tf', None
|
||||
if not graph_file_name and meta_graph_file:
|
||||
meta_graph_file = deducing_metagraph_path(meta_graph_file)
|
||||
input_meta_graph_def = read_file_to_graph_def(tf_v1.MetaGraphDef(), meta_graph_file, is_binary)
|
||||
@@ -212,16 +212,19 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list)
|
||||
graph_def = tf_v1.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def,
|
||||
outputs)
|
||||
return graph_def, variables_values, 'tf'
|
||||
return graph_def, variables_values, 'tf', None
|
||||
if model_dir:
|
||||
# saved model directory
|
||||
try:
|
||||
env_setup = get_environment_setup("tf")
|
||||
# enable eager execution temporarily while TensorFlow 2 model is being loaded
|
||||
tf_v1.enable_eager_execution()
|
||||
# code to extract GraphDef for TF 2.0 SavedModel format
|
||||
# tf.saved_model.load function throws TypeError for TF 1.x SavedModel format in case TF 1.x installed
|
||||
imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
|
||||
|
||||
# Code to extract Keras model.
|
||||
# tf.keras.models.load_model function throws TypeError,KeyError or IndexError
|
||||
# for TF 1.x SavedModel format in case TF 1.x installed
|
||||
imported = tf.keras.models.load_model(model_dir, compile=False)
|
||||
|
||||
# to get a signature by key throws KeyError for TF 1.x SavedModel format in case TF 2.x installed
|
||||
concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
# the aggressive inlining parameter needs to freeze a table of embeddings for Keras Embedding operation
|
||||
@@ -236,8 +239,22 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
|
||||
# disable eager execution since next steps are executed with a graph in non-eager mode
|
||||
tf_v1.disable_eager_execution()
|
||||
return graph_def, variables_values, 'tf2'
|
||||
except (TypeError, KeyError):
|
||||
|
||||
input_names = []
|
||||
if hasattr(imported, 'inputs'):
|
||||
# Extract tensor names order from Keras model
|
||||
input_names = [tensor.name for tensor in imported.inputs]
|
||||
|
||||
# After model freezing output tensor names are changing and recieve "Func/PartitionedCall" prefix,
|
||||
# so output_names from saved_model cannot be used. Here tensor names from frozen graph are used,
|
||||
# as TF adds indexed Identity nodes during freezing to each output, so this indexing is used for
|
||||
# order alignment.
|
||||
output_names = [tensor.name for tensor in frozen_func.outputs]
|
||||
|
||||
inputs_outputs_order = (input_names, output_names)
|
||||
|
||||
return graph_def, variables_values, 'tf2', inputs_outputs_order
|
||||
except:
|
||||
# disable eager execution since TensorFlow 1 model is handled
|
||||
tf_v1.disable_eager_execution()
|
||||
# code to extract GraphDef for TF 1.0 SavedModel format
|
||||
@@ -246,9 +263,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
meta_graph_def = tf_v1.saved_model.loader.load(sess, tags, model_dir)
|
||||
outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list)
|
||||
graph_def = tf_v1.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
|
||||
return graph_def, variables_values, 'tf'
|
||||
except Exception as e:
|
||||
raise FrameworkError('SavedModel format load failure: {}', e) from e
|
||||
return graph_def, variables_values, 'tf', None
|
||||
except Exception as e:
|
||||
raise FrameworkError('Cannot load input model: {}', e) from e
|
||||
raise Error("Unknown configuration of input model parameters")
|
||||
|
||||
@@ -42,13 +42,19 @@ class TFLoader(Loader):
|
||||
log.info('Loading library "{}" with custom operations'.format(library))
|
||||
tf_v1.load_op_library(library)
|
||||
|
||||
graph_def, variables_values, framework = load_tf_graph_def(graph_file_name=argv.input_model,
|
||||
is_binary=not argv.input_model_is_text,
|
||||
checkpoint=argv.input_checkpoint,
|
||||
user_output_node_names_list=argv.output,
|
||||
model_dir=argv.saved_model_dir,
|
||||
meta_graph_file=argv.input_meta_graph,
|
||||
saved_model_tags=argv.saved_model_tags)
|
||||
graph_def, variables_values, framework, inputs_outputs_order = load_tf_graph_def(
|
||||
graph_file_name=argv.input_model,
|
||||
is_binary=not argv.input_model_is_text,
|
||||
checkpoint=argv.input_checkpoint,
|
||||
user_output_node_names_list=argv.output,
|
||||
model_dir=argv.saved_model_dir,
|
||||
meta_graph_file=argv.input_meta_graph,
|
||||
saved_model_tags=argv.saved_model_tags)
|
||||
|
||||
if inputs_outputs_order is not None and isinstance(inputs_outputs_order, tuple):
|
||||
graph.inputs_order = inputs_outputs_order[0]
|
||||
graph.outputs_order = inputs_outputs_order[1]
|
||||
|
||||
send_framework_info(framework)
|
||||
|
||||
try:
|
||||
|
||||
@@ -31,7 +31,7 @@ def convert(filename: str, is_text: bool):
|
||||
new_ext = ".pbtxt" if is_text else ".pb"
|
||||
head, tail = os.path.split(os.path.abspath(filename))
|
||||
print("Convert: {} \n to: {}".format(filename, os.path.join(head, tail + new_ext)))
|
||||
graph_def, _ = load_tf_graph_def(graph_file_name=filename, is_binary=is_text)
|
||||
graph_def, _, _, _ = load_tf_graph_def(graph_file_name=filename, is_binary=is_text)
|
||||
tf_v1.import_graph_def(graph_def, name='')
|
||||
tf_v1.train.write_graph(graph_def, head, tail + new_ext, as_text=is_text)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def main():
|
||||
print("[ ERROR ] Both keys were provided --input_model and --input_dir. Please, provide only one of them")
|
||||
sys.exit(1)
|
||||
tags = argv.saved_model_tags.split(",")
|
||||
graph_def, _, _ = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text,
|
||||
graph_def, _, _, _ = load_tf_graph_def(graph_file_name=argv.input_model, is_binary=not argv.text,
|
||||
checkpoint=argv.input_checkpoint,
|
||||
model_dir=argv.saved_model_dir, saved_model_tags=tags)
|
||||
summary = summarize_graph(graph_def)
|
||||
|
||||
@@ -15,7 +15,7 @@ pbtxt = 'node{name:"Placeholder"op:"Placeholder"attr{key:"dtype"value{type:DT_FL
|
||||
class TestingSummarizeGraph(unittest.TestCase):
|
||||
def test_summarize_graph(self):
|
||||
with patch('openvino.tools.mo.front.tf.loader.open', mock_open(read_data=pbtxt)) as m:
|
||||
graph_def, _, _ = load_tf_graph_def('path', False)
|
||||
graph_def, _, _, _ = load_tf_graph_def('path', False)
|
||||
summary = summarize_graph(graph_def)
|
||||
self.assertEqual(len(summary['outputs']), 1)
|
||||
self.assertEqual(summary['outputs'][0], 'Output/Identity')
|
||||
|
||||
Reference in New Issue
Block a user