Fix unit tests for select layer. (#638)
* Fix unit tests for select layer.
This commit is contained in:
parent
f1811ad060
commit
eefaf56075
@ -49,14 +49,14 @@ class Select(Op):
|
|||||||
node.out_port(0).data.set_shape(broadcast_shape(a_shape, b_shape))
|
node.out_port(0).data.set_shape(broadcast_shape(a_shape, b_shape))
|
||||||
# Case with unknown condition
|
# Case with unknown condition
|
||||||
if condition_value is not None:
|
if condition_value is not None:
|
||||||
|
output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1])
|
||||||
if condition_value.size != 1:
|
if condition_value.size != 1:
|
||||||
output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1])
|
|
||||||
if np.any(output_value == None):
|
if np.any(output_value == None):
|
||||||
# If any element of output value is None that means that we use the value from 'then' or 'else' tensor
|
# If any element of output value is None that means that we use the value from 'then' or 'else' tensor
|
||||||
# which is not defined, this means that we cannot perform value propagation.
|
# which is not defined, this means that we cannot perform value propagation.
|
||||||
output_value = None
|
output_value = None
|
||||||
else:
|
else:
|
||||||
output_value = resulting_tensors[not np.bool(condition_value.item(0))]
|
output_value = np.array(output_value, dtype=resulting_tensors[not np.bool(condition_value.item(0))].dtype)
|
||||||
|
|
||||||
if output_value is not None:
|
if output_value is not None:
|
||||||
node.out_port(0).data.set_value(np.array(output_value))
|
node.out_port(0).data.set_value(np.array(output_value))
|
||||||
|
@ -151,31 +151,52 @@ class TestSelect(unittest.TestCase):
|
|||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
@generate(*[
|
@generate(*[
|
||||||
([5, 6], [5, 6], [5, 6], np.ones(np.array([5, 6]), dtype=np.float),
|
([5, 6], [5, 6], [5, 6], [5, 6], lambda x: np.ones(x, dtype=np.float),
|
||||||
np.zeros([5, 6], dtype=np.float), np.ones([5, 6], dtype=np.float),
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
np.ones([5, 6], dtype=np.float)),
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
np.ones([15, 3, 5], dtype=np.float)),
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
None, np.ones([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float)),
|
lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float)),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
np.ones([15, 3, 5], dtype=np.float), None, None),
|
lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: None),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float),
|
||||||
None, np.ones([15, 3, 5], dtype=np.float), None),
|
lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: None),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float),
|
||||||
np.ones([15, 3, 5], dtype=np.float), None, np.ones([15, 3, 5], dtype=np.float)),
|
lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: np.ones(x, dtype=np.float)),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([True], np.bool),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([True], np.bool),
|
||||||
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
np.ones([15, 3, 5], dtype=np.float)),
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([False], np.bool),
|
([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([False], np.bool),
|
||||||
np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float),
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
np.zeros([15, 3, 5], dtype=np.float)),
|
lambda x: np.zeros(x, dtype=np.float)),
|
||||||
|
([2, 3, 4, 5], [2, 3, 4, 5], [], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
|
([2, 3, 4, 5], [2, 3, 4, 5], [5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
|
([2, 3, 1, 1], [2, 1, 1, 5], [2, 3, 4, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
|
([2, 3, 4, 1], [2, 1, 1, 5], [1, 3, 1, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float),
|
||||||
|
lambda x: np.ones(x, dtype=np.float)),
|
||||||
])
|
])
|
||||||
def test_select_infer_condition_with_value(self, else_data_shape, than_data_shape, select_output_shape,
|
def test_select_infer_condition_with_value(self, condition_shape, else_data_shape, than_data_shape, select_output_shape,
|
||||||
condition_value, else_value, than_value, output_value):
|
condition_value, else_value, than_value, output_value):
|
||||||
|
"""
|
||||||
|
Unit tests generator can sporadic throw exception if we try
|
||||||
|
to run generator with call numpy array generation functions.
|
||||||
|
So we need to use lambda function for escape the problem.
|
||||||
|
"""
|
||||||
|
condition_value = condition_value(condition_shape)
|
||||||
|
else_value = else_value(else_data_shape)
|
||||||
|
than_value = than_value(than_data_shape)
|
||||||
|
output_value = output_value(select_output_shape)
|
||||||
graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
|
graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
|
||||||
update_nodes_attributes=[('condition_data', {'shape': np.array(select_output_shape),
|
update_nodes_attributes=[('condition_data', {'shape': np.array(condition_shape),
|
||||||
'value': condition_value}),
|
'value': condition_value}),
|
||||||
('else_data', {'shape': np.array(else_data_shape),
|
('else_data', {'shape': np.array(else_data_shape),
|
||||||
'value': else_value}),
|
'value': else_value}),
|
||||||
@ -188,7 +209,7 @@ class TestSelect(unittest.TestCase):
|
|||||||
|
|
||||||
graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
|
graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
|
||||||
update_nodes_attributes=[
|
update_nodes_attributes=[
|
||||||
('condition_data', {'shape': np.array(else_data_shape),
|
('condition_data', {'shape': np.array(condition_shape),
|
||||||
'value': condition_value}),
|
'value': condition_value}),
|
||||||
('else_data', {'shape': np.array(else_data_shape),
|
('else_data', {'shape': np.array(else_data_shape),
|
||||||
'value': else_value}),
|
'value': else_value}),
|
||||||
|
Loading…
Reference in New Issue
Block a user