Fix unit tests for select layer. (#638)

* Fix unit tests for select layer.
This commit is contained in:
iliya mironov 2020-06-08 18:39:40 +03:00 committed by GitHub
parent f1811ad060
commit eefaf56075
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 25 deletions

View File

@ -49,14 +49,14 @@ class Select(Op):
node.out_port(0).data.set_shape(broadcast_shape(a_shape, b_shape))
# Case with unknown condition
if condition_value is not None:
output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1])
if condition_value.size != 1:
output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1])
if np.any(output_value == None):
# If any element of output value is None that means that we use the value from 'then' or 'else' tensor
# which is not defined, this means that we cannot perform value propagation.
output_value = None
else:
output_value = resulting_tensors[not np.bool(condition_value.item(0))]
output_value = np.array(output_value, dtype=resulting_tensors[not np.bool(condition_value.item(0))].dtype)
if output_value is not None:
node.out_port(0).data.set_value(np.array(output_value))

View File

@ -151,31 +151,52 @@ class TestSelect(unittest.TestCase):
self.assertTrue(flag, resp)
@generate(*[
([5, 6], [5, 6], [5, 6], np.ones(np.array([5, 6]), dtype=np.float),
np.zeros([5, 6], dtype=np.float), np.ones([5, 6], dtype=np.float),
np.ones([5, 6], dtype=np.float)),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
np.ones([15, 3, 5], dtype=np.float)),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
None, np.ones([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float)),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
np.ones([15, 3, 5], dtype=np.float), None, None),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float),
None, np.ones([15, 3, 5], dtype=np.float), None),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float),
np.ones([15, 3, 5], dtype=np.float), None, np.ones([15, 3, 5], dtype=np.float)),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([True], np.bool),
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
np.ones([15, 3, 5], dtype=np.float)),
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([False], np.bool),
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
np.zeros([15, 3, 5], dtype=np.float)),
([5, 6], [5, 6], [5, 6], [5, 6], lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float)),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: None),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float),
lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: None),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: np.ones(x, dtype=np.float)),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([True], np.bool),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([False], np.bool),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float)),
([2, 3, 4, 5], [2, 3, 4, 5], [], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
([2, 3, 4, 5], [2, 3, 4, 5], [5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
([2, 3, 1, 1], [2, 1, 1, 5], [2, 3, 4, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
([2, 3, 4, 1], [2, 1, 1, 5], [1, 3, 1, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
lambda x: np.ones(x, dtype=np.float)),
])
def test_select_infer_condition_with_value(self, else_data_shape, than_data_shape, select_output_shape,
def test_select_infer_condition_with_value(self, condition_shape, else_data_shape, than_data_shape, select_output_shape,
condition_value, else_value, than_value, output_value):
"""
Unit tests generator can sporadic throw exception if we try
to run generator with call numpy array generation functions.
So we need to use lambda function for escape the problem.
"""
condition_value = condition_value(condition_shape)
else_value = else_value(else_data_shape)
than_value = than_value(than_data_shape)
output_value = output_value(select_output_shape)
graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
update_nodes_attributes=[('condition_data', {'shape': np.array(select_output_shape),
update_nodes_attributes=[('condition_data', {'shape': np.array(condition_shape),
'value': condition_value}),
('else_data', {'shape': np.array(else_data_shape),
'value': else_value}),
@ -188,7 +209,7 @@ class TestSelect(unittest.TestCase):
graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
update_nodes_attributes=[
('condition_data', {'shape': np.array(else_data_shape),
('condition_data', {'shape': np.array(condition_shape),
'value': condition_value}),
('else_data', {'shape': np.array(else_data_shape),
'value': else_value}),