Fixed loader for TF2. (#9962)

This commit is contained in:
Anastasia Popova
2022-01-28 13:23:22 +03:00
committed by GitHub
parent 4cd20425c3
commit a61655040f

View File

@@ -220,10 +220,13 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
# enable eager execution temporarily while TensorFlow 2 model is being loaded
tf_v1.enable_eager_execution()
# Code to extract Keras model.
# tf.keras.models.load_model function throws TypeError,KeyError or IndexError
# for TF 1.x SavedModel format in case TF 1.x installed
imported = tf.keras.models.load_model(model_dir, compile=False)
try:
# Code to extract Keras model.
# tf.keras.models.load_model function throws TypeError,KeyError or IndexError
# for TF 1.x SavedModel format in case TF 1.x installed
imported = tf.keras.models.load_model(model_dir, compile=False)
except:
imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
# to get a signature by key throws KeyError for TF 1.x SavedModel format in case TF 2.x installed
concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]