Smal fix with tf env_setup (#5476)

* Smal fix with tf env_setup

* Fix tf loader

* Fix version checker
This commit is contained in:
iliya mironov 2021-05-12 15:33:03 +03:00 committed by GitHub
parent 7fa93b226e
commit fe5ca28b6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 6 deletions

View File

@ -215,7 +215,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
if model_dir:
# saved model directory
try:
env_setup = get_environment_setup()
env_setup = get_environment_setup("tf")
# enable eager execution temporarily while TensorFlow 2 model is being loaded
tf_v1.enable_eager_execution()
# code to extract GraphDef for TF 2.0 SavedModel format

View File

@ -196,9 +196,10 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v):
not_satisfied_v.append((name, 'installed: {}'.format(installed_v), 'required: {} {}'.format(sign, required_v)))
def get_environment_setup():
def get_environment_setup(framework):
"""
Get environment setup such as Python version, TensorFlow version
:param framework: framework name
:return: a dictionary of environment variables
"""
env_setup = dict()
@ -207,9 +208,10 @@ def get_environment_setup():
sys.version_info.micro)
env_setup['python_version'] = python_version
try:
exec("import tensorflow")
env_setup['tensorflow'] = sys.modules["tensorflow"].__version__
exec("del tensorflow")
if framework == 'tf':
exec("import tensorflow")
env_setup['tensorflow'] = sys.modules["tensorflow"].__version__
exec("del tensorflow")
except (AttributeError, ImportError):
pass
env_setup['sys_platform'] = sys.platform
@ -228,7 +230,7 @@ def check_requirements(framework=None):
:param framework: framework name
:return: exit code (0 - execution successful, 1 - error)
"""
env_setup = get_environment_setup()
env_setup = get_environment_setup(framework)
if framework is None:
framework_suffix = ""
elif framework == "tf":