diff --git a/model-optimizer/mo/utils/cli_parser.py b/model-optimizer/mo/utils/cli_parser.py index bb2918db462..d4ca43b3861 100644 --- a/model-optimizer/mo/utils/cli_parser.py +++ b/model-optimizer/mo/utils/cli_parser.py @@ -917,44 +917,35 @@ def parse_tuple_pairs(argv_values: str): if not argv_values: return res - data_str = argv_values - while True: - tuples_matches = re.findall(r'[(\[]([0-9., -]+)[)\]]', data_str, re.IGNORECASE) - if not tuples_matches : - raise Error( - "Mean/scale values should be in format: data(1,2,3),info(2,3,4)" + - " or just plain set of them without naming any inputs: (1,2,3),(2,3,4). " + - refer_to_faq_msg(101), argv_values) - tuple_value = tuples_matches[0] - matches = data_str.split(tuple_value) + matches = [m for m in re.finditer(r'[(\[]([0-9., -]+)[)\]]', argv_values, re.IGNORECASE)] - input_name = matches[0][:-1] - if not input_name: - res = [] - # check that other values are specified w/o names - words_reg = r'([a-zA-Z]+)' - for i in range(0, len(matches)): - if re.search(words_reg, matches[i]) is not None: - # error - tuple with name is also specified - raise Error( - "Mean/scale values should either contain names of input layers: data(1,2,3),info(2,3,4)" + - " or just plain set of them without naming any inputs: (1,2,3),(2,3,4)." + - refer_to_faq_msg(101), argv_values) - for match in tuples_matches: - res.append(np.fromstring(match, dtype=float, sep=',')) - break + error_msg = 'Mean/scale values should consist of name and values specified in round or square brackets' \ + 'separated by comma, e.g. data(1,2,3),info[2,3,4],egg[255] or data(1,2,3). Or just plain set of ' \ + 'values without names: (1,2,3),(2,3,4) or [1,2,3],[2,3,4].' + refer_to_faq_msg(101) + if not matches: + raise Error(error_msg, argv_values) - res[input_name] = np.fromstring(tuple_value, dtype=float, sep=',') + name_start_idx = 0 + name_was_present = False + for idx, match in enumerate(matches): + input_name = argv_values[name_start_idx:match.start(0)] + name_start_idx = match.end(0) + 1 + tuple_value = np.fromstring(match.groups()[0], dtype=float, sep=',') - parenthesis = matches[0][-1] - sibling = ')' if parenthesis == '(' else ']' - pair = '{}{}{}{}'.format(input_name, parenthesis, tuple_value, sibling) - idx_substr = data_str.index(pair) - data_str = data_str[idx_substr + len(pair) + 1:] + if idx != 0 and (name_was_present ^ bool(input_name)): + # if node name firstly was specified and then subsequently not or vice versa + # e.g. (255),input[127] or input(255),[127] + raise Error(error_msg, argv_values) - if not data_str: - break + name_was_present = True if input_name != "" else False + if name_was_present: + res[input_name] = tuple_value + else: + res[idx] = tuple_value + if not name_was_present: + # return a list instead of a dictionary + res = sorted(res.values(), key=lambda v: v[0]) return res @@ -1183,4 +1174,3 @@ def get_meta_info(argv: argparse.Namespace): if key in meta_data: meta_data[key] = ','.join([os.path.join('DIR', os.path.split(i)[1]) for i in meta_data[key].split(',')]) return meta_data - diff --git a/model-optimizer/mo/utils/cli_parser_test.py b/model-optimizer/mo/utils/cli_parser_test.py index deefa5ad9ae..36ec7cbcd28 100644 --- a/model-optimizer/mo/utils/cli_parser_test.py +++ b/model-optimizer/mo/utils/cli_parser_test.py @@ -22,6 +22,7 @@ import unittest from unittest.mock import patch import numpy as np +import numpy.testing as npt from mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_mean_scale_dictionary, get_model_name, \ parse_tuple_pairs, check_positive, writable_dir, readable_dirs, \ @@ -38,7 +39,17 @@ class TestingMeanScaleGetter(unittest.TestCase): 'info': np.array([2.2, 33.33, 444.444]) } for el in exp_res.keys(): - np.array_equal(result[el], exp_res[el]) + npt.assert_array_equal(result[el], exp_res[el]) + + def test_tuple_parser_name_digits_only(self): + tuple_values = "0448(1.1,22.22,333.333),0449[2.2,33.33,444.444]" + result = parse_tuple_pairs(tuple_values) + exp_res = { + '0448': np.array([1.1, 22.22, 333.333]), + '0449': np.array([2.2, 33.33, 444.444]) + } + for el in exp_res.keys(): + npt.assert_array_equal(result[el], exp_res[el]) def test_tuple_parser_same_values(self): tuple_values = "data(1.1,22.22,333.333),info[1.1,22.22,333.333]" @@ -48,7 +59,7 @@ class TestingMeanScaleGetter(unittest.TestCase): 'info': np.array([1.1, 22.22, 333.333]) } for el in exp_res.keys(): - np.array_equal(result[el], exp_res[el]) + npt.assert_array_equal(result[el], exp_res[el]) def test_tuple_parser_no_inputs(self): tuple_values = "(1.1,22.22,333.333),[2.2,33.33,444.444]" @@ -56,12 +67,24 @@ class TestingMeanScaleGetter(unittest.TestCase): exp_res = [np.array([1.1, 22.22, 333.333]), np.array([2.2, 33.33, 444.444])] for i in range(0, len(exp_res)): - np.array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) - def test_tuple_parser_error(self): + def test_tuple_parser_error_mixed_with_and_without_name(self): tuple_values = "(1.1,22.22,333.333),data[2.2,33.33,444.444]" self.assertRaises(Error, parse_tuple_pairs, tuple_values) + def test_tuple_parser_error_mixed_with_and_without_name_1(self): + tuple_values = "data(1.1,22.22,333.333),[2.2,33.33,444.444]" + self.assertRaises(Error, parse_tuple_pairs, tuple_values) + + def test_tuple_parser_error_mixed_with_and_without_name_digits(self): + tuple_values = "(0.1,22.22,333.333),0448[2.2,33.33,444.444]" + self.assertRaises(Error, parse_tuple_pairs, tuple_values) + + def test_tuple_parser_error_mixed_with_and_without_name_digits_1(self): + tuple_values = "447(1.1,22.22,333.333),[2.2,33.33,444.444]" + self.assertRaises(Error, parse_tuple_pairs, tuple_values) + def test_mean_scale_no_input(self): mean_values = "data(1.1,22.22,333.333)" scale_values = "info[1.1,22.22,333.333]" @@ -79,7 +102,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -100,7 +123,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -116,7 +139,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -132,7 +155,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -151,7 +174,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for i in range(len(exp_res)): for j in range(len(exp_res[i])): if type(exp_res[i][j]) is np.ndarray: - np.array_equal(exp_res[i][j], result[i][j]) + npt.assert_array_equal(exp_res[i][j], result[i][j]) else: self.assertEqual(exp_res[i][j], result[i][j]) @@ -170,7 +193,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -193,7 +216,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -216,7 +239,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -235,7 +258,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -258,7 +281,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for i in range(len(exp_res)): for j in range(len(exp_res[i])): if type(exp_res[i][j]) is np.ndarray: - np.array_equal(exp_res[i][j], result[i][j]) + npt.assert_array_equal(exp_res[i][j], result[i][j]) else: self.assertEqual(exp_res[i][j], result[i][j]) @@ -279,7 +302,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -301,7 +324,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -323,7 +346,7 @@ class TestingMeanScaleGetter(unittest.TestCase): for input in exp_res.keys(): for key in exp_res[input].keys(): if type(exp_res[input][key]) is np.ndarray: - np.array_equal(exp_res[input][key], result[input][key]) + npt.assert_array_equal(exp_res[input][key], result[input][key]) else: self.assertEqual(exp_res[input][key], result[input][key]) @@ -343,7 +366,7 @@ class TestingMeanScaleGetter(unittest.TestCase): ] for i in range(0, len(exp_res)): for j in range(0, len(exp_res[i])): - np.array_equal(exp_res[i][j], result[i][j]) + npt.assert_array_equal(exp_res[i][j], result[i][j]) def test_scale_do_not_match_input(self): scale_values = parse_tuple_pairs("input_not_present(255),input2(255)") @@ -355,6 +378,20 @@ class TestingMeanScaleGetter(unittest.TestCase): mean_values = parse_tuple_pairs("input_not_present(255),input2(255)") self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2") + def test_values_match_input_name(self): + # to be sure that we correctly processes complex names + res_values = parse_tuple_pairs("input255(255),input255.0(255.0),multi-dotted.input.3.(255,128,64)") + exp_res = {'input255': np.array([255.0]), + 'input255.0': np.array([255.0]), + 'multi-dotted.input.3.': np.array([255., 128., 64.])} + self.assertEqual(len(exp_res), len(res_values)) + for i, j in zip(exp_res, res_values): + self.assertEqual(i, j) + npt.assert_array_equal(exp_res[i], res_values[j]) + + def test_input_without_values(self): + self.assertRaises(Error, parse_tuple_pairs, "input1,input2") + class TestSingleTupleParsing(unittest.TestCase): def test_get_values_ideal(self): values = "(1.11, 22.22, 333.333)" @@ -425,7 +462,7 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) def test_get_shapes_several_inputs_several_shapes2(self): # shapes specified using --input command line parameter and no values @@ -434,13 +471,13 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([1, 22, 333, 123]), 'inp2': np.array([-1, 45, 7, 1])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_ref = {} input_node_names_ref = "inp1,inp2" self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): - np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) def test_get_shapes_several_inputs_several_shapes3(self): # shapes and value for freezing specified using --input command line parameter @@ -449,13 +486,13 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} input_node_names_ref = "inp1,inp2,inp3" self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): - np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) def test_get_shapes_several_inputs_several_shapes4(self): # shapes specified using --input_shape and values for freezing using --input command line parameter @@ -465,13 +502,13 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'])} input_node_names_ref = "inp1,inp2,inp3" self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): - np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) self.assertEqual(input_node_names_ref, input_node_names_res) def test_get_shapes_several_inputs_several_shapes5(self): @@ -484,14 +521,14 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array([5])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, argv_freeze_placeholder_with_value) placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': np.array(['1.0', '1.0', '2.0', '3.0', '5.0'],), 'inp2': np.array(['5.0', '7.0', '3.0']), 'inp4': np.array(['100.0', '200.0'])} input_node_names_ref = "inp1,inp2,inp3" self.assertEqual(sorted(list(placeholder_values_res.keys())), sorted(list(placeholder_values_ref.keys()))) for i in placeholder_values_ref.keys(): - np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) self.assertEqual(input_node_names_ref, input_node_names_res) def test_get_shapes_several_inputs_several_shapes6(self): @@ -501,12 +538,12 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': False} self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): - np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) def test_get_shapes_several_inputs_several_shapes7(self): # 0D shape and value for freezing specified using --input command line parameter @@ -515,12 +552,12 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([3, 1]), 'inp2': np.array([3, 2, 3]), 'inp3': np.array(False).shape} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) placeholder_values_res, input_node_names_res = get_freeze_placeholder_values(argv_input, None) placeholder_values_ref = {'inp1': np.array(['1.0', '2.0', '3.0']), 'inp3': True} self.assertEqual(list(placeholder_values_res.keys()), list(placeholder_values_ref.keys())) for i in placeholder_values_ref.keys(): - np.testing.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) + npt.assert_array_equal(placeholder_values_res[i], placeholder_values_ref[i]) def test_get_shapes_and_data_types1(self): argv_input = "inp1[3 1]->[1.0 2.0 3.0],inp2[3 2 3]{i32},inp3[5]{f32}->[1.0 1.0 2.0 3.0 5.0]" @@ -529,7 +566,7 @@ class TestShapesParsing(unittest.TestCase): ref_result_data_types = {'inp2': np.int32, 'inp3': np.float32} self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) for i in ref_result_shapes.keys(): - np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) for i in ref_result_data_types.keys(): np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) @@ -541,7 +578,7 @@ class TestShapesParsing(unittest.TestCase): ref_result_data_types = {'inp2': np.int32, '0:inp3': np.float32} self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) for i in ref_result_shapes.keys(): - np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) for i in ref_result_data_types.keys(): np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) @@ -553,7 +590,7 @@ class TestShapesParsing(unittest.TestCase): ref_result_data_types = {'inp2': np.int32, 'inp3:4': np.float32} self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) for i in ref_result_shapes.keys(): - np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) for i in ref_result_data_types.keys(): np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) @@ -566,7 +603,7 @@ class TestShapesParsing(unittest.TestCase): ref_result_data_types = {} self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) for i in ref_result_shapes.keys(): - np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) for i in ref_result_data_types.keys(): np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) @@ -579,7 +616,7 @@ class TestShapesParsing(unittest.TestCase): ref_result_data_types = {} self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) for i in ref_result_shapes.keys(): - np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) for i in ref_result_data_types.keys(): np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) @@ -592,7 +629,7 @@ class TestShapesParsing(unittest.TestCase): ref_result_data_types = {'placeholder1': np.int32, 'placeholder3': np.int32} self.assertEqual(list(ref_result_shapes.keys()), list(result_shapes.keys())) for i in ref_result_shapes.keys(): - np.testing.assert_array_equal(result_shapes[i], ref_result_shapes[i]) + npt.assert_array_equal(result_shapes[i], ref_result_shapes[i]) self.assertEqual(list(ref_result_data_types.keys()), list(result_data_types.keys())) for i in ref_result_data_types.keys(): np.testing.assert_equal(result_data_types[i], ref_result_data_types[i]) @@ -641,28 +678,28 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([1, 22, 333, 123])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) def test_get_shapes_no_input_no_shape(self): argv_input = "" input_shapes = "" result, _ = get_placeholder_shapes(argv_input, input_shapes) exp_res = np.array([None]) - np.testing.assert_array_equal(result, exp_res) + npt.assert_array_equal(result, exp_res) def test_get_shapes_no_input_one_shape(self): argv_input = "" input_shapes = "(12,4,1)" result, _ = get_placeholder_shapes(argv_input, input_shapes) exp_res = np.array([12, 4, 1]) - np.testing.assert_array_equal(result, exp_res) + npt.assert_array_equal(result, exp_res) def test_get_shapes_no_input_one_shape2(self): argv_input = "" input_shapes = "[12,4,1]" result, _ = get_placeholder_shapes(argv_input, input_shapes) exp_res = np.array([12, 4, 1]) - np.testing.assert_array_equal(result, exp_res) + npt.assert_array_equal(result, exp_res) def test_get_shapes_no_input_two_shapes(self): argv_input = "" @@ -676,7 +713,7 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([None])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) def test_get_shapes_one_input_wrong_shape8(self): argv_input = "inp1" @@ -735,7 +772,7 @@ class TestShapesParsing(unittest.TestCase): exp_res = {'inp1': np.array([-1, 4, 1]), 'inp2': np.array([4, 6, 8])} self.assertEqual(list(exp_res.keys()), list(result.keys())) for i in exp_res.keys(): - np.testing.assert_array_equal(result[i], exp_res[i]) + npt.assert_array_equal(result[i], exp_res[i]) def test_get_shapes_one_input_first_neg_shape_not_one(self): argv_input = "inp1"