[MO] cli_parser fix when input contains substring with matching scale/mean values (#3146)
* fix MO cli_parser when input contains substring with matching scale/mean values * some additions to cli_parser unit-tests * fixed numpy array comparisons -- added assert_ prefix * more general solution for mean/scale cli_parser, names with only digit values are processed correctly * minor corrections
This commit is contained in:
parent
b5930eb58e
commit
1076d32467
@ -917,44 +917,35 @@ def parse_tuple_pairs(argv_values: str):
|
||||
if not argv_values:
|
||||
return res
|
||||
|
||||
data_str = argv_values
|
||||
while True:
|
||||
tuples_matches = re.findall(r'[(\[]([0-9., -]+)[)\]]', data_str, re.IGNORECASE)
|
||||
if not tuples_matches :
|
||||
raise Error(
|
||||
"Mean/scale values should be in format: data(1,2,3),info(2,3,4)" +
|
||||
" or just plain set of them without naming any inputs: (1,2,3),(2,3,4). " +
|
||||
refer_to_faq_msg(101), argv_values)
|
||||
tuple_value = tuples_matches[0]
|
||||
matches = data_str.split(tuple_value)
|
||||
matches = [m for m in re.finditer(r'[(\[]([0-9., -]+)[)\]]', argv_values, re.IGNORECASE)]
|
||||
|
||||
input_name = matches[0][:-1]
|
||||
if not input_name:
|
||||
res = []
|
||||
# check that other values are specified w/o names
|
||||
words_reg = r'([a-zA-Z]+)'
|
||||
for i in range(0, len(matches)):
|
||||
if re.search(words_reg, matches[i]) is not None:
|
||||
# error - tuple with name is also specified
|
||||
raise Error(
|
||||
"Mean/scale values should either contain names of input layers: data(1,2,3),info(2,3,4)" +
|
||||
" or just plain set of them without naming any inputs: (1,2,3),(2,3,4)." +
|
||||
refer_to_faq_msg(101), argv_values)
|
||||
for match in tuples_matches:
|
||||
res.append(np.fromstring(match, dtype=float, sep=','))
|
||||
break
|
||||
error_msg = 'Mean/scale values should consist of name and values specified in round or square brackets' \
|
||||
'separated by comma, e.g. data(1,2,3),info[2,3,4],egg[255] or data(1,2,3). Or just plain set of ' \
|
||||
'values without names: (1,2,3),(2,3,4) or [1,2,3],[2,3,4].' + refer_to_faq_msg(101)
|
||||
if not matches:
|
||||
raise Error(error_msg, argv_values)
|
||||
|
||||
res[input_name] = np.fromstring(tuple_value, dtype=float, sep=',')
|
||||
name_start_idx = 0
|
||||
name_was_present = False
|
||||
for idx, match in enumerate(matches):
|
||||
input_name = argv_values[name_start_idx:match.start(0)]
|
||||
name_start_idx = match.end(0) + 1
|
||||
tuple_value = np.fromstring(match.groups()[0], dtype=float, sep=',')
|
||||
|
||||
parenthesis = matches[0][-1]
|
||||
sibling = ')' if parenthesis == '(' else ']'
|
||||
pair = '{}{}{}{}'.format(input_name, parenthesis, tuple_value, sibling)
|
||||
idx_substr = data_str.index(pair)
|
||||
data_str = data_str[idx_substr + len(pair) + 1:]
|
||||
if idx != 0 and (name_was_present ^ bool(input_name)):
|
||||
# if node name firstly was specified and then subsequently not or vice versa
|
||||
# e.g. (255),input[127] or input(255),[127]
|
||||
raise Error(error_msg, argv_values)
|
||||
|
||||
if not data_str:
|
||||
break
|
||||
name_was_present = True if input_name != "" else False
|
||||
if name_was_present:
|
||||
res[input_name] = tuple_value
|
||||
else:
|
||||
res[idx] = tuple_value
|
||||
|
||||
if not name_was_present:
|
||||
# return a list instead of a dictionary
|
||||
res = sorted(res.values(), key=lambda v: v[0])
|
||||
return res
|
||||
|
||||
|
||||
@ -1183,4 +1174,3 @@ def get_meta_info(argv: argparse.Namespace):
|
||||
if key in meta_data:
|
||||
meta_data[key] = ','.join([os.path.join('DIR', os.path.split(i)[1]) for i in meta_data[key].split(',')])
|
||||
return meta_data
|
||||
|
||||
|
@ -22,6 +22,7 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
|
||||
from mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \
|
||||
parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \
|
||||
@ -38,7 +39,17 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
'info': np.array([2.2, 33.33, 444.444])
|
||||
}
|
||||
for el in exp_res.keys():
|
||||
np.array_equal(result[el], exp_res[el])
|
||||
npt.assert_array_equal(result[el], exp_res[el])
|
||||
|
||||
def test_tuple_parser_name_digits_only(self):
|
||||
tuple_values = "0448(1.1,22.22,333.333),0449[2.2,33.33,444.444]"
|
||||
result = parse_tuple_pairs(tuple_values)
|
||||
exp_res = {
|
||||
'0448': np.array([1.1, 22.22, 333.333]),
|
||||
'0449': np.array([2.2, 33.33, 444.444])
|
||||
}
|
||||
for el in exp_res.keys():
|
||||
npt.assert_array_equal(result[el], exp_res[el])
|
||||
|
||||
def test_tuple_parser_same_values(self):
|
||||
tuple_values = "data(1.1,22.22,333.333),info[1.1,22.22,333.333]"
|
||||
@ -48,7 +59,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
'info': np.array([1.1, 22.22, 333.333])
|
||||
}
|
||||
for el in exp_res.keys():
|
||||
np.array_equal(result[el], exp_res[el])
|
||||
npt.assert_array_equal(result[el], exp_res[el])
|
||||
|
||||
def test_tuple_parser_no_inputs(self):
|
||||
tuple_values = "(1.1,22.22,333.333),[2.2,33.33,444.444]"
|
||||
@ -56,12 +67,24 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
exp_res = [np.array([1.1, 22.22, 333.333]),
|
||||
np.array([2.2, 33.33, 444.444])]
|
||||
for i in range(0, len(exp_res)):
|
||||
np.array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_tuple_parser_error(self):
|
||||
def test_tuple_parser_error_mixed_with_and_without_name(self):
|
||||
tuple_values = "(1.1,22.22,333.333),data[2.2,33.33,444.444]"
|
||||
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
|
||||
|
||||
def test_tuple_parser_error_mixed_with_and_without_name_1(self):
|
||||
tuple_values = "data(1.1,22.22,333.333),[2.2,33.33,444.444]"
|
||||
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
|
||||
|
||||
def test_tuple_parser_error_mixed_with_and_without_name_digits(self):
|
||||
tuple_values = "(0.1,22.22,333.333),0448[2.2,33.33,444.444]"
|
||||
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
|
||||
|
||||
def test_tuple_parser_error_mixed_with_and_without_name_digits_1(self):
|
||||
tuple_values = "447(1.1,22.22,333.333),[2.2,33.33,444.444]"
|
||||
self.assertRaises(Error, parse_tuple_pairs, tuple_values)
|
||||
|
||||
def test_mean_scale_no_input(self):
|
||||
mean_values = "data(1.1,22.22,333.333)"
|
||||
scale_values = "info[1.1,22.22,333.333]"
|
||||
@ -79,7 +102,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -100,7 +123,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -116,7 +139,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -132,7 +155,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -151,7 +174,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for i in range(len(exp_res)):
|
||||
for j in range(len(exp_res[i])):
|
||||
if type(exp_res[i][j]) is np.ndarray:
|
||||
np.array_equal(exp_res[i][j], result[i][j])
|
||||
npt.assert_array_equal(exp_res[i][j], result[i][j])
|
||||
else:
|
||||
self.assertEqual(exp_res[i][j], result[i][j])
|
||||
|
||||
@ -170,7 +193,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -193,7 +216,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -216,7 +239,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -235,7 +258,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -258,7 +281,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for i in range(len(exp_res)):
|
||||
for j in range(len(exp_res[i])):
|
||||
if type(exp_res[i][j]) is np.ndarray:
|
||||
np.array_equal(exp_res[i][j], result[i][j])
|
||||
npt.assert_array_equal(exp_res[i][j], result[i][j])
|
||||
else:
|
||||
self.assertEqual(exp_res[i][j], result[i][j])
|
||||
|
||||
@ -279,7 +302,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -301,7 +324,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -323,7 +346,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for input in exp_res.keys():
|
||||
for key in exp_res[input].keys():
|
||||
if type(exp_res[input][key]) is np.ndarray:
|
||||
np.array_equal(exp_res[input][key], result[input][key])
|
||||
npt.assert_array_equal(exp_res[input][key], result[input][key])
|
||||
else:
|
||||
self.assertEqual(exp_res[input][key], result[input][key])
|
||||
|
||||
@ -343,7 +366,7 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
]
|
||||
for i in range(0, len(exp_res)):
|
||||
for j in range(0, len(exp_res[i])):
|
||||
np.array_equal(exp_res[i][j], result[i][j])
|
||||
npt.assert_array_equal(exp_res[i][j], result[i][j])
|
||||
|
||||
def test_scale_do_not_match_input(self):
|
||||
scale_values = parse_tuple_pairs("input_not_present(255),input2(255)")
|
||||
@ -355,6 +378,20 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
mean_values = parse_tuple_pairs("input_not_present(255),input2(255)")
|
||||
self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
|
||||
|
||||
def test_values_match_input_name(self):
|
||||
# to be sure that we correctly processes complex names
|
||||
res_values = parse_tuple_pairs("input255(255),input255.0(255.0),multi-dotted.input.3.(255,128,64)")
|
||||
exp_res = {'input255': np.array([255.0]),
|
||||
'input255.0': np.array([255.0]),
|
||||
'multi-dotted.input.3.': np.array([255., 128., 64.])}
|
||||
self.assertEqual(len(exp_res), len(res_values))
|
||||
for i, j in zip(exp_res, res_values):
|
||||
self.assertEqual(i, j)
|
||||
npt.assert_array_equal(exp_res[i], res_values[j])
|
||||
|
||||
def test_input_without_values(self):
|
||||
self.assertRaises(Error, parse_tuple_pairs, "input1,input2")
|
||||
|
||||
class TestSingleTupleParsing(unittest.TestCase):
|
||||
def test_get_values_ideal(self):
|
||||
values = "(1.11, 22.22, 333.333)"
|
||||
@ -425,7 +462,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_shapes2(self):
|
||||
# shapes specified using --input command line parameter and no values
|
||||
@ -434,13 +471,13 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {}
|
||||
input_node_names_ref = "inp1,inp2"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_shapes3(self):
|
||||
# shapes and value for freezing specified using --input command line parameter
|
||||
@ -449,13 +486,13 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_shapes4(self):
|
||||
# shapes specified using --input_shape and values for freezing using --input command line parameter
|
||||
@ -465,13 +502,13 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
self.assertEqual(input_node_names_ref, input_node_names_res)
|
||||
|
||||
def test_get_shapes_several_inputs_several_shapes5(self):
|
||||
@ -484,14 +521,14 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],),
|
||||
'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])}
|
||||
input_node_names_ref = "inp1,inp2,inp3"
|
||||
self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys())))
|
||||
for i in placeholder_values_ref.keys():
|
||||
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
self.assertEqual(input_node_names_ref, input_node_names_res)
|
||||
|
||||
def test_get_shapes_several_inputs_several_shapes6(self):
|
||||
@ -501,12 +538,12 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False}
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_several_inputs_several_shapes7(self):
|
||||
# 0D shape and value for freezing specified using --input command line parameter
|
||||
@ -515,12 +552,12 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None)
|
||||
placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True}
|
||||
self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys()))
|
||||
for i in placeholder_values_ref.keys():
|
||||
np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i])
|
||||
|
||||
def test_get_shapes_and_data_types1(self):
|
||||
argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]"
|
||||
@ -529,7 +566,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
ref_result_data_types = {'inp2': np.int32, 'inp3': np.float32}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
@ -541,7 +578,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
ref_result_data_types = {'inp2': np.int32, '0:inp3': np.float32}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
@ -553,7 +590,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
ref_result_data_types = {'inp2': np.int32, 'inp3:4': np.float32}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
@ -566,7 +603,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
ref_result_data_types = {}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
@ -579,7 +616,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
ref_result_data_types = {}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
@ -592,7 +629,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
ref_result_data_types = {'placeholder1': np.int32, 'placeholder3': np.int32}
|
||||
self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys()))
|
||||
for i in ref_result_shapes.keys():
|
||||
np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
npt.assert_array_equal(result_shapes[i], ref_result_shapes[i])
|
||||
self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys()))
|
||||
for i in ref_result_data_types.keys():
|
||||
np.testing.assert_equal(result_data_types[i], ref_result_data_types[i])
|
||||
@ -641,28 +678,28 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([1, 22, 333, 123])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_shapes_no_input_no_shape(self):
|
||||
argv_input = ""
|
||||
input_shapes = ""
|
||||
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
exp_res = np.array([None])
|
||||
np.testing.assert_array_equal(result, exp_res)
|
||||
npt.assert_array_equal(result, exp_res)
|
||||
|
||||
def test_get_shapes_no_input_one_shape(self):
|
||||
argv_input = ""
|
||||
input_shapes = "(12,4,1)"
|
||||
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
exp_res = np.array([12, 4, 1])
|
||||
np.testing.assert_array_equal(result, exp_res)
|
||||
npt.assert_array_equal(result, exp_res)
|
||||
|
||||
def test_get_shapes_no_input_one_shape2(self):
|
||||
argv_input = ""
|
||||
input_shapes = "[12,4,1]"
|
||||
result, _ = get_placeholder_shapes(argv_input, input_shapes)
|
||||
exp_res = np.array([12, 4, 1])
|
||||
np.testing.assert_array_equal(result, exp_res)
|
||||
npt.assert_array_equal(result, exp_res)
|
||||
|
||||
def test_get_shapes_no_input_two_shapes(self):
|
||||
argv_input = ""
|
||||
@ -676,7 +713,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([None])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_shapes_one_input_wrong_shape8(self):
|
||||
argv_input = "inp1"
|
||||
@ -735,7 +772,7 @@ class TestShapesParsing(unittest.TestCase):
|
||||
exp_res = {'inp1': np.array([-1, 4, 1]), 'inp2': np.array([4, 6, 8])}
|
||||
self.assertEqual(list(exp_res.keys()), list(result.keys()))
|
||||
for i in exp_res.keys():
|
||||
np.testing.assert_array_equal(result[i], exp_res[i])
|
||||
npt.assert_array_equal(result[i], exp_res[i])
|
||||
|
||||
def test_get_shapes_one_input_first_neg_shape_not_one(self):
|
||||
argv_input = "inp1"
|
||||
|
Loading…
Reference in New Issue
Block a user