Fixed loader for TF2. (#9962)
This commit is contained in:
@@ -220,10 +220,13 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
||||
# enable eager execution temporarily while TensorFlow 2 model is being loaded
|
||||
tf_v1.enable_eager_execution()
|
||||
|
||||
# Code to extract Keras model.
|
||||
# tf.keras.models.load_model function throws TypeError,KeyError or IndexError
|
||||
# for TF 1.x SavedModel format in case TF 1.x installed
|
||||
imported = tf.keras.models.load_model(model_dir, compile=False)
|
||||
try:
|
||||
# Code to extract Keras model.
|
||||
# tf.keras.models.load_model function throws TypeError,KeyError or IndexError
|
||||
# for TF 1.x SavedModel format in case TF 1.x installed
|
||||
imported = tf.keras.models.load_model(model_dir, compile=False)
|
||||
except:
|
||||
imported = tf.saved_model.load(model_dir, saved_model_tags) # pylint: disable=E1120
|
||||
|
||||
# to get a signature by key throws KeyError for TF 1.x SavedModel format in case TF 2.x installed
|
||||
concrete_func = imported.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||
|
||||
Reference in New Issue
Block a user