Fix skipping incorrect names in scale/mean values (#535)

* Fix skipping incorrect names in scale/mean values

* removed inappropriate comment in cli_parser.py
This commit is contained in:
Pavel Esir
2020-05-27 14:53:50 +03:00
committed by GitHub
parent d24132912e
commit e337350cc1
2 changed files with 22 additions and 2 deletions

View File

@@ -1011,7 +1011,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
Returns
-------
The function returns a dictionary e.g.
mean = { 'data: np.array, 'info': np.array }, scale = { 'data: np.array, 'info': np.array }, input = "data, info" ->
mean = { 'data': np.array, 'info': np.array }, scale = { 'data': np.array, 'info': np.array }, input = "data, info" ->
{ 'data': { 'mean': np.array, 'scale': np.array }, 'info': { 'mean': np.array, 'scale': np.array } }
"""
@@ -1032,6 +1032,17 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
if type(mean_values) is dict and type(scale_values) is dict:
if not mean_values and not scale_values:
return res
for inp_scale in scale_values.keys():
if inp_scale not in inputs:
raise Error("Specified scale_values name '{}' do not match to any of inputs: {}. "
"Please set 'scale_values' that correspond to values from input.".format(inp_scale, inputs))
for inp_mean in mean_values.keys():
if inp_mean not in inputs:
raise Error("Specified mean_values name '{}' do not match to any of inputs: {}. "
"Please set 'mean_values' that correspond to values from input.".format(inp_mean, inputs))
for inp in inputs:
inp, port = split_node_in_port(inp)
if inp in mean_values or inp in scale_values:
@@ -1105,7 +1116,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
}
)
return res
# mean and scale are specified without inputs, return list, order is not guaranteed (?)
# mean and/or scale are specified without inputs
return list(zip_longest(mean_values, scale_values))

View File

@@ -345,6 +345,15 @@ class TestingMeanScaleGetter(unittest.TestCase):
for j in range(0, len(exp_res[i])):
np.array_equal(exp_res[i][j], result[i][j])
def test_scale_do_not_match_input(self):
scale_values = parse_tuple_pairs("input_not_present(255),input2(255)")
mean_values = parse_tuple_pairs("input1(255),input2(255)")
self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
def test_mean_do_not_match_input(self):
scale_values = parse_tuple_pairs("input1(255),input2(255)")
mean_values = parse_tuple_pairs("input_not_present(255),input2(255)")
self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
class TestSingleTupleParsing(unittest.TestCase):
def test_get_values_ideal(self):