Fix unit tests for select layer. (#638)
* Fix unit tests for select layer.
This commit is contained in:
parent
f1811ad060
commit
eefaf56075
@ -49,14 +49,14 @@ class Select(Op):
|
||||
node.out_port(0).data.set_shape(broadcast_shape(a_shape, b_shape))
|
||||
# Case with unknown condition
|
||||
if condition_value is not None:
|
||||
output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1])
|
||||
if condition_value.size != 1:
|
||||
output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1])
|
||||
if np.any(output_value == None):
|
||||
# If any element of output value is None that means that we use the value from 'then' or 'else' tensor
|
||||
# which is not defined, this means that we cannot perform value propagation.
|
||||
output_value = None
|
||||
else:
|
||||
output_value = resulting_tensors[not np.bool(condition_value.item(0))]
|
||||
output_value = np.array(output_value, dtype=resulting_tensors[not np.bool(condition_value.item(0))].dtype)
|
||||
|
||||
if output_value is not None:
|
||||
node.out_port(0).data.set_value(np.array(output_value))
|
||||
|
@ -151,31 +151,52 @@ class TestSelect(unittest.TestCase):
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
@generate(*[
|
||||
([5, 6], [5, 6], [5, 6], np.ones(np.array([5, 6]), dtype=np.float),
|
||||
np.zeros([5, 6], dtype=np.float), np.ones([5, 6], dtype=np.float),
|
||||
np.ones([5, 6], dtype=np.float)),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
|
||||
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
|
||||
np.ones([15, 3, 5], dtype=np.float)),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
|
||||
None, np.ones([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float)),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
|
||||
np.ones([15, 3, 5], dtype=np.float), None, None),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float),
|
||||
None, np.ones([15, 3, 5], dtype=np.float), None),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float),
|
||||
np.ones([15, 3, 5], dtype=np.float), None, np.ones([15, 3, 5], dtype=np.float)),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([True], np.bool),
|
||||
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
|
||||
np.ones([15, 3, 5], dtype=np.float)),
|
||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([False], np.bool),
|
||||
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
|
||||
np.zeros([15, 3, 5], dtype=np.float)),
|
||||
([5, 6], [5, 6], [5, 6], [5, 6], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float)),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: None),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float),
|
||||
lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: None),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: np.ones(x, dtype=np.float)),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([True], np.bool),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([False], np.bool),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float)),
|
||||
([2, 3, 4, 5], [2, 3, 4, 5], [], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
([2, 3, 4, 5], [2, 3, 4, 5], [5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
([2, 3, 1, 1], [2, 1, 1, 5], [2, 3, 4, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
([2, 3, 4, 1], [2, 1, 1, 5], [1, 3, 1, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||
lambda x: np.ones(x, dtype=np.float)),
|
||||
])
|
||||
def test_select_infer_condition_with_value(self, else_data_shape, than_data_shape, select_output_shape,
|
||||
def test_select_infer_condition_with_value(self, condition_shape, else_data_shape, than_data_shape, select_output_shape,
|
||||
condition_value, else_value, than_value, output_value):
|
||||
"""
|
||||
Unit tests generator can sporadic throw exception if we try
|
||||
to run generator with call numpy array generation functions.
|
||||
So we need to use lambda function for escape the problem.
|
||||
"""
|
||||
condition_value = condition_value(condition_shape)
|
||||
else_value = else_value(else_data_shape)
|
||||
than_value = than_value(than_data_shape)
|
||||
output_value = output_value(select_output_shape)
|
||||
graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
|
||||
update_nodes_attributes=[('condition_data', {'shape': np.array(select_output_shape),
|
||||
update_nodes_attributes=[('condition_data', {'shape': np.array(condition_shape),
|
||||
'value': condition_value}),
|
||||
('else_data', {'shape': np.array(else_data_shape),
|
||||
'value': else_value}),
|
||||
@ -188,7 +209,7 @@ class TestSelect(unittest.TestCase):
|
||||
|
||||
graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
|
||||
update_nodes_attributes=[
|
||||
('condition_data', {'shape': np.array(else_data_shape),
|
||||
('condition_data', {'shape': np.array(condition_shape),
|
||||
'value': condition_value}),
|
||||
('else_data', {'shape': np.array(else_data_shape),
|
||||
'value': else_value}),
|
||||
|
Loading…
Reference in New Issue
Block a user