Support python 3.8 by the Model Optimizer tool in default configuration (#2078)

* Support python 3.8 by the Model Optimizer tool in default configuration

* Fix after review #1

* Fix after the second round review
This commit is contained in:
Roman Kazantsev 2020-09-09 08:34:43 +03:00 committed by GitHub
parent 5ad4811793
commit 82e15a5a64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 118 additions and 18 deletions

View File

@ -44,17 +44,24 @@ def check_python_version():
return 1
def parse_versions_list(required_fw_versions, version_list):
def parse_and_filter_versions_list(required_fw_versions, version_list, env_setup):
"""
Please do not add parameter type annotations (param:type).
Because we import this file while checking Python version.
Python 2.x will fail with no clear message on type annotations.
Parsing requirements versions
Parsing requirements versions for a dependency and filtering out requirements that
satisfy environment setup such as python version.
if environment version (python_version, etc.) is satisfied
:param required_fw_versions: String with fw versions from requirements file
:param version_list: List for append
:param env_setup: a dictionary with environment setup
:return: list of tuples of strings like (name_of_module, sign, version)
Examples of required_fw_versions:
'tensorflow>=1.15.2,<2.0; python_version < "3.8"'
'tensorflow>=2.0'
Returned object is:
[('tensorflow', '>=', '1.2.0'), ('networkx', '==', '2.1'), ('numpy', None, None)]
"""
@ -62,26 +69,57 @@ def parse_versions_list(required_fw_versions, version_list):
line = required_fw_versions.strip('\n')
line = line.strip(' ')
if line == '':
return []
splited_versions_by_conditions = re.split(r"==|>=|<=|>|<", line)
return version_list
splited_requirement = line.split(";")
# check environment marker
if len(splited_requirement) > 1:
env_req = splited_requirement[1]
splited_env_req = re.split(r"==|>=|<=|>|<", env_req)
splited_env_req = [l.strip(',') for l in splited_env_req]
env_marker = splited_env_req[0].strip(' ')
if env_marker == 'python_version' and env_marker in env_setup:
installed_python_version = env_setup['python_version']
env_req_version_list = []
splited_required_versions = re.split(r",", env_req)
for i, l in enumerate(splited_required_versions):
for comparison in ['==', '>=', '<=', '<', '>']:
if comparison in l:
required_version = splited_env_req[i + 1].strip(' ').replace('"', '')
env_req_version_list.append((env_marker, comparison, required_version))
break
not_satisfied_list = []
for name, key, required_version in env_req_version_list:
version_check(name, installed_python_version, required_version,
key, not_satisfied_list, 0)
if len(not_satisfied_list) > 0:
# this python_version requirement is not satisfied to required environment
# and requirement for a dependency will be skipped
return version_list
else:
log.error("{} is unsupported environment marker and it will be ignored".format(env_marker),
extra={'is_warning': True})
# parse a requirement for a dependency
requirement = splited_requirement[0]
splited_versions_by_conditions = re.split(r"==|>=|<=|>|<", requirement)
splited_versions_by_conditions = [l.strip(',') for l in splited_versions_by_conditions]
if len(splited_versions_by_conditions) == 0:
return []
return version_list
if len(splited_versions_by_conditions) == 1:
version_list.append((splited_versions_by_conditions[0], None, None))
else:
splited_required_versions= re.split(r",", line)
splited_required_versions= re.split(r",", requirement)
for i, l in enumerate(splited_required_versions):
comparisons = ['==', '>=', '<=', '<', '>']
for comparison in comparisons:
for comparison in ['==', '>=', '<=', '<', '>']:
if comparison in l:
version_list.append((splited_versions_by_conditions[0], comparison, splited_versions_by_conditions[i + 1]))
break
return version_list
def get_module_version_list_from_file(file_name):
def get_module_version_list_from_file(file_name, env_setup):
"""
Please do not add parameter type annotations (param:type).
Because we import this file while checking Python version.
@ -89,6 +127,7 @@ def get_module_version_list_from_file(file_name):
Reads file with requirements
:param file_name: Name of the requirements file
:param env_setup: a dictionary with environment setup elements
:return: list of tuples of strings like (name_of_module, sign, version)
File content example:
@ -102,7 +141,7 @@ def get_module_version_list_from_file(file_name):
req_dict = list()
with open(file_name) as f:
for line in f:
req_dict = parse_versions_list(line, req_dict)
req_dict = parse_and_filter_versions_list(line, req_dict, env_setup)
return req_dict
@ -113,7 +152,7 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v, exit_cod
Python 2.x will fail with no clear message on type annotations.
Evaluates comparison of installed and required versions according to requirements file of one module.
If installed version does not satisfy requirements appends this module to not_stisfied_v list.
If installed version does not satisfy requirements appends this module to not_satisfied_v list.
:param name: module name
:param installed_v: installed version of module
:param required_v: required version of module
@ -146,6 +185,25 @@ def version_check(name, installed_v, required_v, sign, not_satisfied_v, exit_cod
return exit_code
def get_environment_setup():
"""
Get environment setup such as Python version, TensorFlow version
:return: a dictionary of environment variables
"""
env_setup = dict()
python_version = "{}.{}.{}".format(sys.version_info.major,
sys.version_info.minor,
sys.version_info.micro)
env_setup['python_version'] = python_version
try:
exec("import tensorflow")
env_setup['tensorflow'] = sys.modules["tensorflow"].__version__
exec("del tensorflow")
except (AttributeError, ImportError):
pass
return env_setup
def check_requirements(framework=None):
"""
Please do not add parameter type annotations (param:type).
@ -158,13 +216,20 @@ def check_requirements(framework=None):
:param framework: framework name
:return: exit code (0 - execution successful, 1 - error)
"""
env_setup = get_environment_setup()
if framework is None:
framework_suffix = ""
elif framework == "tf":
if "tensorflow" in env_setup and env_setup["tensorflow"] >= LooseVersion("2.0.0"):
framework_suffix = "_tf2"
else:
framework_suffix = "_tf"
else:
framework_suffix = "_{}".format(framework)
file_name = "requirements{}.txt".format(framework_suffix)
requirements_file = os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, file_name))
requirements_list = get_module_version_list_from_file(requirements_file)
requirements_list = get_module_version_list_from_file(requirements_file, env_setup)
not_satisfied_versions = []
exit_code = 0
for name, key, required_version in requirements_list:

View File

@ -18,7 +18,7 @@ import unittest
import unittest.mock as mock
from unittest.mock import mock_open
from mo.utils.versions_checker import get_module_version_list_from_file, parse_versions_list
from mo.utils.versions_checker import get_module_version_list_from_file, parse_and_filter_versions_list
class TestingVersionsChecker(unittest.TestCase):
@ -30,18 +30,51 @@ class TestingVersionsChecker(unittest.TestCase):
ref_list =[('mxnet', '>=', '1.0.0'), ('mxnet', '<=', '1.3.1'),
('networkx', '>=', '1.11'),
('numpy', '==', '1.12.0'), ('defusedxml', '<=', '0.5.0')]
version_list = get_module_version_list_from_file('mock_file')
version_list = get_module_version_list_from_file('mock_file', {})
self.assertEqual(len(version_list), 5)
for i, version_dict in enumerate(version_list):
self.assertTupleEqual(ref_list[i], version_dict)
@mock.patch('builtins.open', new_callable=mock_open, create=True)
def test_get_module_version_list_from_file2(self, mock_open):
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.__iter__ = mock.Mock(
return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"',
'tensorflow>=2.0; python_version >= "3.8"',
'numpy==1.12.0',
'defusedxml<=0.5.0']))
ref_list =[('tensorflow', '>=', '1.15.2'),
('tensorflow', '<', '2.0'),
('numpy', '==', '1.12.0'),
('defusedxml', '<=', '0.5.0')]
version_list = get_module_version_list_from_file('mock_file', {'python_version': '3.7.0'})
self.assertEqual(len(version_list), 4)
for i, version_dict in enumerate(version_list):
self.assertTupleEqual(ref_list[i], version_dict)
@mock.patch('builtins.open', new_callable=mock_open, create=True)
def test_get_module_version_list_from_file3(self, mock_open):
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.__iter__ = mock.Mock(
return_value=iter(['tensorflow>=1.15.2,<2.0; python_version < "3.8"',
'tensorflow>=2.0; python_version >= "3.8"',
'numpy==1.12.0',
'defusedxml<=0.5.0']))
ref_list =[('tensorflow', '>=', '2.0'),
('numpy', '==', '1.12.0'),
('defusedxml', '<=', '0.5.0')]
version_list = get_module_version_list_from_file('mock_file', {'python_version': '3.8.1'})
self.assertEqual(len(version_list), 3)
for i, version_dict in enumerate(version_list):
self.assertTupleEqual(ref_list[i], version_dict)
@mock.patch('builtins.open', new_callable=mock_open, create=True)
def test_get_module_version_list_from_file_with_fw_name(self, mock_open):
mock_open.return_value.__enter__ = mock_open
mock_open.return_value.__iter__ = mock.Mock(
return_value=iter(['mxnet']))
ref_list = [('mxnet', None, None)]
version_list = get_module_version_list_from_file('mock_file')
version_list = get_module_version_list_from_file('mock_file', {})
self.assertEqual(len(version_list), 1)
for i, version_dict in enumerate(version_list):
self.assertTupleEqual(ref_list[i], version_dict)
@ -49,7 +82,7 @@ class TestingVersionsChecker(unittest.TestCase):
def test_append_version_list(self):
v1 = 'mxnet>=1.0.0,<=1.3.1'
req_list = list()
parse_versions_list(v1, req_list)
parse_and_filter_versions_list(v1, req_list, {})
ref_list = [('mxnet', '>=', '1.0.0'),
('mxnet', '<=', '1.3.1')]
for i, v in enumerate(req_list):

View File

@ -1,4 +1,5 @@
tensorflow>=1.15.2,<2.0
tensorflow>=1.15.2,<2.0; python_version < "3.8"
tensorflow>=2.0; python_version >= "3.8"
mxnet>=1.0.0,<=1.5.1
networkx>=1.11
numpy>=1.13.0

View File

@ -1,4 +1,5 @@
tensorflow>=1.15.2,<2.0
tensorflow>=1.15.2,<2.0; python_version < "3.8"
tensorflow>=2.0; python_version >= "3.8"
networkx>=1.11
numpy>=1.13.0
test-generator==0.1.1