Smal fix with tf env_setup (#5476)
* Smal fix with tf env_setup * Fix tf loader * Fix version checker
This commit is contained in:
parent
7fa93b226e
commit
fe5ca28b6e
@ -215,7 +215,7 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
|
|||||||
if model_dir:
|
if model_dir:
|
||||||
# saved model directory
|
# saved model directory
|
||||||
try:
|
try:
|
||||||
env_setup = get_environment_setup()
|
env_setup = get_environment_setup("tf")
|
||||||
# enable eager execution temporarily while TensorFlow 2 model is being loaded
|
# enable eager execution temporarily while TensorFlow 2 model is being loaded
|
||||||
tf_v1.enable_eager_execution()
|
tf_v1.enable_eager_execution()
|
||||||
# code to extract GraphDef for TF 2.0 SavedModel format
|
# code to extract GraphDef for TF 2.0 SavedModel format
|
||||||
|
@ -196,9 +196,10 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v):
|
|||||||
not_satisfied_v.append((name, 'installed: {}'.format(installed_v), 'required: {} {}'.format(sign, required_v)))
|
not_satisfied_v.append((name, 'installed: {}'.format(installed_v), 'required: {} {}'.format(sign, required_v)))
|
||||||
|
|
||||||
|
|
||||||
def get_environment_setup():
|
def get_environment_setup(framework):
|
||||||
"""
|
"""
|
||||||
Get environment setup such as Python version, TensorFlow version
|
Get environment setup such as Python version, TensorFlow version
|
||||||
|
:param framework: framework name
|
||||||
:return: a dictionary of environment variables
|
:return: a dictionary of environment variables
|
||||||
"""
|
"""
|
||||||
env_setup = dict()
|
env_setup = dict()
|
||||||
@ -207,6 +208,7 @@ def get_environment_setup():
|
|||||||
sys.version_info.micro)
|
sys.version_info.micro)
|
||||||
env_setup['python_version'] = python_version
|
env_setup['python_version'] = python_version
|
||||||
try:
|
try:
|
||||||
|
if framework == 'tf':
|
||||||
exec("import tensorflow")
|
exec("import tensorflow")
|
||||||
env_setup['tensorflow'] = sys.modules["tensorflow"].__version__
|
env_setup['tensorflow'] = sys.modules["tensorflow"].__version__
|
||||||
exec("del tensorflow")
|
exec("del tensorflow")
|
||||||
@ -228,7 +230,7 @@ def check_requirements(framework=None):
|
|||||||
:param framework: framework name
|
:param framework: framework name
|
||||||
:return: exit code (0 - execution successful, 1 - error)
|
:return: exit code (0 - execution successful, 1 - error)
|
||||||
"""
|
"""
|
||||||
env_setup = get_environment_setup()
|
env_setup = get_environment_setup(framework)
|
||||||
if framework is None:
|
if framework is None:
|
||||||
framework_suffix = ""
|
framework_suffix = ""
|
||||||
elif framework == "tf":
|
elif framework == "tf":
|
||||||
|
Loading…
Reference in New Issue
Block a user