Fix skipping incorrect names in scale/mean values (#535)
* Fix skipping incorrect names in scale/mean values * removed inappropriate comment in cli_parser.py
This commit is contained in:
@@ -1011,7 +1011,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
|
||||
Returns
|
||||
-------
|
||||
The function returns a dictionary e.g.
|
||||
mean = { 'data: np.array, 'info': np.array }, scale = { 'data: np.array, 'info': np.array }, input = "data, info" ->
|
||||
mean = { 'data': np.array, 'info': np.array }, scale = { 'data': np.array, 'info': np.array }, input = "data, info" ->
|
||||
{ 'data': { 'mean': np.array, 'scale': np.array }, 'info': { 'mean': np.array, 'scale': np.array } }
|
||||
|
||||
"""
|
||||
@@ -1032,6 +1032,17 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
|
||||
if type(mean_values) is dict and type(scale_values) is dict:
|
||||
if not mean_values and not scale_values:
|
||||
return res
|
||||
|
||||
for inp_scale in scale_values.keys():
|
||||
if inp_scale not in inputs:
|
||||
raise Error("Specified scale_values name '{}' do not match to any of inputs: {}. "
|
||||
"Please set 'scale_values' that correspond to values from input.".format(inp_scale, inputs))
|
||||
|
||||
for inp_mean in mean_values.keys():
|
||||
if inp_mean not in inputs:
|
||||
raise Error("Specified mean_values name '{}' do not match to any of inputs: {}. "
|
||||
"Please set 'mean_values' that correspond to values from input.".format(inp_mean, inputs))
|
||||
|
||||
for inp in inputs:
|
||||
inp, port = split_node_in_port(inp)
|
||||
if inp in mean_values or inp in scale_values:
|
||||
@@ -1105,7 +1116,7 @@ def get_mean_scale_dictionary(mean_values, scale_values, argv_input: str):
|
||||
}
|
||||
)
|
||||
return res
|
||||
# mean and scale are specified without inputs, return list, order is not guaranteed (?)
|
||||
# mean and/or scale are specified without inputs
|
||||
return list(zip_longest(mean_values, scale_values))
|
||||
|
||||
|
||||
|
||||
@@ -345,6 +345,15 @@ class TestingMeanScaleGetter(unittest.TestCase):
|
||||
for j in range(0, len(exp_res[i])):
|
||||
np.array_equal(exp_res[i][j], result[i][j])
|
||||
|
||||
def test_scale_do_not_match_input(self):
|
||||
scale_values = parse_tuple_pairs("input_not_present(255),input2(255)")
|
||||
mean_values = parse_tuple_pairs("input1(255),input2(255)")
|
||||
self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
|
||||
|
||||
def test_mean_do_not_match_input(self):
|
||||
scale_values = parse_tuple_pairs("input1(255),input2(255)")
|
||||
mean_values = parse_tuple_pairs("input_not_present(255),input2(255)")
|
||||
self.assertRaises(Error, get_mean_scale_dictionary, mean_values, scale_values, "input1,input2")
|
||||
|
||||
class TestSingleTupleParsing(unittest.TestCase):
|
||||
def test_get_values_ideal(self):
|
||||
|
||||
Reference in New Issue
Block a user