Enable eager execution while TensorFlow 2 model is loaded (#1945)
This commit is contained in:
parent
70839c1663
commit
9d5a6cff70
@ -226,6 +226,8 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
if model_dir:
|
||||
# saved model directory
|
||||
try:
|
||||
# enable eager execution temporarily while TensorFlow 2 model is being loaded
|
||||
tf_v1.enable_eager_execution()
|
||||
# code to extract GraphDef for TF 2.0 SavedModel format
|
||||
# tf.saved_model.load function throws TypeError for TF 1.x SavedModel format in case TF 1.x installed
|
||||
imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
|
||||
@ -233,8 +235,12 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
frozen_func = convert_variables_to_constants_v2(concrete_func, lower_control_flow=False) # pylint: disable=E1123
|
||||
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
|
||||
# disable eager execution since next steps are executed with a graph in non-eager mode
|
||||
tf_v1.disable_eager_execution()
|
||||
return graph_def, variables_values
|
||||
except (TypeError, KeyError):
|
||||
# disable eager execution since TensorFlow 1 model is handled
|
||||
tf_v1.disable_eager_execution()
|
||||
# code to extract GraphDef for TF 1.0 SavedModel format
|
||||
tags = saved_model_tags if saved_model_tags is not None else [tf_v1.saved_model.tag_constants.SERVING]
|
||||
with tf_v1.Session() as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user