[MO] cli_parser fix when input contains substring with matching scale/mean values (#3146)

* fix MO cli_parser when input contains substring with matching scale/mean values

* some additions to cli_parser unit-tests

* fixed numpy array comparisons -- added assert_ prefix

* more general solution for mean/scale cli_parser, names with only digit values are processed correctly

* minor corrections
This commit is contained in:
Pavel Esir 2020-11-20 11:42:42 +03:00 committed by GitHub
parent b5930eb58e
commit 1076d32467
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 104 additions and 77 deletions

View File

@ -917,44 +917,35 @@ def parse_tuple_pairs(argv_values: str):
if not argv_values: if not argv_values:
return res return res
data_str = argv_values matches = [m for m in re.finditer(r'[(\[]([0-9., -]+)[)\]]', argv_values, re.IGNORECASE)]
while True:
tuples_matches = re.findall(r'[(\[]([0-9., -]+)[)\]]', data_str, re.IGNORECASE)
if not tuples_matches :
raise Error(
"Mean/scale values should be in format: data(1,2,3),info(2,3,4)" +
" or just plain set of them without naming any inputs: (1,2,3),(2,3,4). " +
refer_to_faq_msg(101), argv_values)
tuple_value = tuples_matches[0]
matches = data_str.split(tuple_value)
input_name = matches[0][:-1] error_msg = 'Mean/scale values should consist of name and values specified in round or square brackets' \
if not input_name: 'separated by comma, e.g. data(1,2,3),info[2,3,4],egg[255] or data(1,2,3). Or just plain set of ' \
res = [] 'values without names: (1,2,3),(2,3,4) or [1,2,3],[2,3,4].' + refer_to_faq_msg(101)
# check that other values are specified w/o names if not matches:
words_reg = r'([a-zA-Z]+)' raise Error(error_msg, argv_values)
for i in range(0, len(matches)):
if re.search(words_reg, matches[i]) is not None:
# error - tuple with name is also specified
raise Error(
"Mean/scale values should either contain names of input layers: data(1,2,3),info(2,3,4)" +
" or just plain set of them without naming any inputs: (1,2,3),(2,3,4)." +
refer_to_faq_msg(101), argv_values)
for match in tuples_matches:
res.append(np.fromstring(match, dtype=float, sep=','))
break
res[input_name] = np.fromstring(tuple_value, dtype=float, sep=',') name_start_idx = 0
name_was_present = False
for idx, match in enumerate(matches):
input_name = argv_values[name_start_idx:match.start(0)]
name_start_idx = match.end(0) + 1
tuple_value = np.fromstring(match.groups()[0], dtype=float, sep=',')
parenthesis = matches[0][-1] if idx != 0 and (name_was_present ^ bool(input_name)):
sibling = ')' if parenthesis == '(' else ']' # if node name firstly was specified and then subsequently not or vice versa
pair = '{}{}{}{}'.format(input_name, parenthesis, tuple_value, sibling) # e.g. (255),input[127] or input(255),[127]
idx_substr = data_str.index(pair) raise Error(error_msg, argv_values)
data_str = data_str[idx_substr + len(pair) + 1:]
if not data_str: name_was_present = True if input_name != "" else False
break if name_was_present:
res[input_name] = tuple_value
else:
res[idx] = tuple_value
if not name_was_present:
# return a list instead of a dictionary
res = sorted(res.values(), key=lambda v: v[0])
return res return res
@ -1183,4 +1174,3 @@ def get_meta_info(argv: argparse.Namespace):
if key in meta_data: if key in meta_data:
meta_data[key] = ','.join([os.path.join('DIR', os.path.split(i)[1]) for i in meta_data[key].split(',')]) meta_data[key] = ','.join([os.path.join('DIR', os.path.split(i)[1]) for i in meta_data[key].split(',')])
return meta_data return meta_data

View File

@ -22,6 +22,7 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
import numpy.testing as npt
from mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \ from mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \ parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
@ -38,7 +39,17 @@ class TestingMeanScaleGetter(unittest.TestCase):
'info': np.array([2.2, 33.33, 444.444]) 'info': np.array([2.2, 33.33, 444.444])
} }
for el in exp_res.keys(): for el in exp_res.keys():
np.array_equal(result[el], exp_res[el]) npt.assert_array_equal(result[el], exp_res[el])
def test_tuple_parser_name_digits_only(self):
tuple_values = "0448(1.1,22.22,333.333),0449[2.2,33.33,444.444]"
result = parse_tuple_pairs(tuple_values)
exp_res = {
'0448': np.array([1.1, 22.22, 333.333]),
'0449': np.array([2.2, 33.33, 444.444])
}
for el in exp_res.keys():
npt.assert_array_equal(result[el], exp_res[el])
def test_tuple_parser_same_values(self): def test_tuple_parser_same_values(self):
tuple_values = "data(1.1,22.22,333.333),info[1.1,22.22,333.333]" tuple_values = "data(1.1,22.22,333.333),info[1.1,22.22,333.333]"
@ -48,7 +59,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
'info': np.array([1.1, 22.22, 333.333]) 'info': np.array([1.1, 22.22, 333.333])
} }
for el in exp_res.keys(): for el in exp_res.keys():
np.array_equal(result[el], exp_res[el]) npt.assert_array_equal(result[el], exp_res[el])
def test_tuple_parser_no_inputs(self): def test_tuple_parser_no_inputs(self):
tuple_values = "(1.1,22.22,333.333),[2.2,33.33,444.444]" tuple_values = "(1.1,22.22,333.333),[2.2,33.33,444.444]"
@ -56,12 +67,24 @@ class TestingMeanScaleGetter(unittest.TestCase):
exp_res = [np.array([1.1, 22.22, 333.333]), exp_res = [np.array([1.1, 22.22, 333.333]),
np.array([2.2, 33.33, 444.444])] np.array([2.2, 33.33, 444.444])]
for i in range(0, len(exp_res)): for i in range(0, len(exp_res)):
np.array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
def test_tuple_parser_error(self): def test_tuple_parser_error_mixed_with_and_without_name(self):
tuple_values = "(1.1,22.22,333.333),data[2.2,33.33,444.444]" tuple_values = "(1.1,22.22,333.333),data[2.2,33.33,444.444]"
self.assertRaises(Error, parse_tuple_pairs, tuple_values) self.assertRaises(Error, parse_tuple_pairs, tuple_values)
def test_tuple_parser_error_mixed_with_and_without_name_1(self):
tuple_values = "data(1.1,22.22,333.333),[2.2,33.33,444.444]"
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
def test_tuple_parser_error_mixed_with_and_without_name_digits(self):
tuple_values = "(0.1,22.22,333.333),0448[2.2,33.33,444.444]"
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
def test_tuple_parser_error_mixed_with_and_without_name_digits_1(self):
tuple_values = "447(1.1,22.22,333.333),[2.2,33.33,444.444]"
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
def test_mean_scale_no_input(self): def test_mean_scale_no_input(self):
mean_values = "data(1.1,22.22,333.333)" mean_values = "data(1.1,22.22,333.333)"
scale_values = "info[1.1,22.22,333.333]" scale_values = "info[1.1,22.22,333.333]"
@ -79,7 +102,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -100,7 +123,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -116,7 +139,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -132,7 +155,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -151,7 +174,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for i in range(len(exp_res)): for i in range(len(exp_res)):
for j in range(len(exp_res[i])): for j in range(len(exp_res[i])):
if type(exp_res[i][j]) is np.ndarray: if type(exp_res[i][j]) is np.ndarray:
np.array_equal(exp_res[i][j], result[i][j]) npt.assert_array_equal(exp_res[i][j], result[i][j])
else: else:
self.assertEqual(exp_res[i][j], result[i][j]) self.assertEqual(exp_res[i][j], result[i][j])
@ -170,7 +193,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -193,7 +216,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -216,7 +239,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -235,7 +258,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -258,7 +281,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for i in range(len(exp_res)): for i in range(len(exp_res)):
for j in range(len(exp_res[i])): for j in range(len(exp_res[i])):
if type(exp_res[i][j]) is np.ndarray: if type(exp_res[i][j]) is np.ndarray:
np.array_equal(exp_res[i][j], result[i][j]) npt.assert_array_equal(exp_res[i][j], result[i][j])
else: else:
self.assertEqual(exp_res[i][j], result[i][j]) self.assertEqual(exp_res[i][j], result[i][j])
@ -279,7 +302,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -301,7 +324,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -323,7 +346,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
for input in exp_res.keys(): for input in exp_res.keys():
for key in exp_res[input].keys(): for key in exp_res[input].keys():
if type(exp_res[input][key]) is np.ndarray: if type(exp_res[input][key]) is np.ndarray:
np.array_equal(exp_res[input][key], result[input][key]) npt.assert_array_equal(exp_res[input][key], result[input][key])
else: else:
self.assertEqual(exp_res[input][key], result[input][key]) self.assertEqual(exp_res[input][key], result[input][key])
@ -343,7 +366,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
] ]
for i in range(0, len(exp_res)): for i in range(0, len(exp_res)):
for j in range(0, len(exp_res[i])): for j in range(0, len(exp_res[i])):
np.array_equal(exp_res[i][j], result[i][j]) npt.assert_array_equal(exp_res[i][j], result[i][j])
def test_scale_do_not_match_input(self): def test_scale_do_not_match_input(self):
scale_values = parse_tuple_pairs("input_not_present(255),input2(255)") scale_values = parse_tuple_pairs("input_not_present(255),input2(255)")
@ -355,6 +378,20 @@ class TestingMeanScaleGetter(unittest.TestCase):
mean_values = parse_tuple_pairs("input_not_present(255),input2(255)") mean_values = parse_tuple_pairs("input_not_present(255),input2(255)")
self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2") self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
def test_values_match_input_name(self):
# to be sure that we correctly processes complex names
res_values = parse_tuple_pairs("input255(255),input255.0(255.0),multi-dotted.input.3.(255,128,64)")
exp_res = {'input255': np.array([255.0]),
'input255.0': np.array([255.0]),
'multi-dotted.input.3.': np.array([255., 128., 64.])}
self.assertEqual(len(exp_res), len(res_values))
for i, j in zip(exp_res, res_values):
self.assertEqual(i, j)
npt.assert_array_equal(exp_res[i], res_values[j])
def test_input_without_values(self):
self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
class TestSingleTupleParsing(unittest.TestCase): class TestSingleTupleParsing(unittest.TestCase):
def test_get_values_ideal(self): def test_get_values_ideal(self):
values = "(1.11, 22.22, 333.333)" values = "(1.11, 22.22, 333.333)"
@ -425,7 +462,7 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])} exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
def test_get_shapes_several_inputs_several_shapes2(self): def test_get_shapes_several_inputs_several_shapes2(self):
# shapes specified using --input command line parameter and no values # shapes specified using --input command line parameter and no values
@ -434,13 +471,13 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])} exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {} placeholder_values_ref = {}
input_node_names_ref = "inp1,inp2" input_node_names_ref = "inp1,inp2"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_several_inputs_several_shapes3(self): def test_get_shapes_several_inputs_several_shapes3(self):
# shapes and value for freezing specified using --input command line parameter # shapes and value for freezing specified using --input command line parameter
@ -449,13 +486,13 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])} exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3" input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_several_inputs_several_shapes4(self): def test_get_shapes_several_inputs_several_shapes4(self):
# shapes specified using --input_shape and values for freezing using --input command line parameter # shapes specified using --input_shape and values for freezing using --input command line parameter
@ -465,13 +502,13 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])} exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
input_node_names_ref = "inp1,inp2,inp3" input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
self.assertEqual(input_node_names_ref, input_node_names_res) self.assertEqual(input_node_names_ref, input_node_names_res)
def test_get_shapes_several_inputs_several_shapes5(self): def test_get_shapes_several_inputs_several_shapes5(self):
@ -484,14 +521,14 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])} exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],), placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])} 'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
input_node_names_ref = "inp1,inp2,inp3" input_node_names_ref = "inp1,inp2,inp3"
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys()))) self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
self.assertEqual(input_node_names_ref, input_node_names_res) self.assertEqual(input_node_names_ref, input_node_names_res)
def test_get_shapes_several_inputs_several_shapes6(self): def test_get_shapes_several_inputs_several_shapes6(self):
@ -501,12 +538,12 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape} exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False} placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False}
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_several_inputs_several_shapes7(self): def test_get_shapes_several_inputs_several_shapes7(self):
# 0D shape and value for freezing specified using --input command line parameter # 0D shape and value for freezing specified using --input command line parameter
@ -515,12 +552,12 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape} exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True} placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True}
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
for i in placeholder_values_ref.keys(): for i in placeholder_values_ref.keys():
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
def test_get_shapes_and_data_types1(self): def test_get_shapes_and_data_types1(self):
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]" argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
@ -529,7 +566,7 @@ class TestShapesParsing(unittest.TestCase):
ref_result_data_types = {'inp2': np.int32, 'inp3': np.float32} ref_result_data_types = {'inp2': np.int32, 'inp3': np.float32}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys(): for i in ref_result_shapes.keys():
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys(): for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
@ -541,7 +578,7 @@ class TestShapesParsing(unittest.TestCase):
ref_result_data_types = {'inp2': np.int32, '0:inp3': np.float32} ref_result_data_types = {'inp2': np.int32, '0:inp3': np.float32}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys(): for i in ref_result_shapes.keys():
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys(): for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
@ -553,7 +590,7 @@ class TestShapesParsing(unittest.TestCase):
ref_result_data_types = {'inp2': np.int32, 'inp3:4': np.float32} ref_result_data_types = {'inp2': np.int32, 'inp3:4': np.float32}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys(): for i in ref_result_shapes.keys():
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys(): for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
@ -566,7 +603,7 @@ class TestShapesParsing(unittest.TestCase):
ref_result_data_types = {} ref_result_data_types = {}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys(): for i in ref_result_shapes.keys():
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys(): for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
@ -579,7 +616,7 @@ class TestShapesParsing(unittest.TestCase):
ref_result_data_types = {} ref_result_data_types = {}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys(): for i in ref_result_shapes.keys():
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys(): for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
@ -592,7 +629,7 @@ class TestShapesParsing(unittest.TestCase):
ref_result_data_types = {'placeholder1': np.int32, 'placeholder3': np.int32} ref_result_data_types = {'placeholder1': np.int32, 'placeholder3': np.int32}
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
for i in ref_result_shapes.keys(): for i in ref_result_shapes.keys():
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
for i in ref_result_data_types.keys(): for i in ref_result_data_types.keys():
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
@ -641,28 +678,28 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([1, 22, 333, 123])} exp_res = {'inp1': np.array([1, 22, 333, 123])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
def test_get_shapes_no_input_no_shape(self): def test_get_shapes_no_input_no_shape(self):
argv_input = "" argv_input = ""
input_shapes = "" input_shapes = ""
result, _ = get_placeholder_shapes(argv_input, input_shapes) result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = np.array([None]) exp_res = np.array([None])
np.testing.assert_array_equal(result, exp_res) npt.assert_array_equal(result, exp_res)
def test_get_shapes_no_input_one_shape(self): def test_get_shapes_no_input_one_shape(self):
argv_input = "" argv_input = ""
input_shapes = "(12,4,1)" input_shapes = "(12,4,1)"
result, _ = get_placeholder_shapes(argv_input, input_shapes) result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = np.array([12, 4, 1]) exp_res = np.array([12, 4, 1])
np.testing.assert_array_equal(result, exp_res) npt.assert_array_equal(result, exp_res)
def test_get_shapes_no_input_one_shape2(self): def test_get_shapes_no_input_one_shape2(self):
argv_input = "" argv_input = ""
input_shapes = "[12,4,1]" input_shapes = "[12,4,1]"
result, _ = get_placeholder_shapes(argv_input, input_shapes) result, _ = get_placeholder_shapes(argv_input, input_shapes)
exp_res = np.array([12, 4, 1]) exp_res = np.array([12, 4, 1])
np.testing.assert_array_equal(result, exp_res) npt.assert_array_equal(result, exp_res)
def test_get_shapes_no_input_two_shapes(self): def test_get_shapes_no_input_two_shapes(self):
argv_input = "" argv_input = ""
@ -676,7 +713,7 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([None])} exp_res = {'inp1': np.array([None])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
def test_get_shapes_one_input_wrong_shape8(self): def test_get_shapes_one_input_wrong_shape8(self):
argv_input = "inp1" argv_input = "inp1"
@ -735,7 +772,7 @@ class TestShapesParsing(unittest.TestCase):
exp_res = {'inp1': np.array([-1, 4, 1]), 'inp2': np.array([4, 6, 8])} exp_res = {'inp1': np.array([-1, 4, 1]), 'inp2': np.array([4, 6, 8])}
self.assertEqual(list(exp_res.keys()), list(result.keys())) self.assertEqual(list(exp_res.keys()), list(result.keys()))
for i in exp_res.keys(): for i in exp_res.keys():
np.testing.assert_array_equal(result[i], exp_res[i]) npt.assert_array_equal(result[i], exp_res[i])
def test_get_shapes_one_input_first_neg_shape_not_one(self): def test_get_shapes_one_input_first_neg_shape_not_one(self):
argv_input = "inp1" argv_input = "inp1"