diff --git a/model-optimizer/extensions/ops/select.py b/model-optimizer/extensions/ops/select.py index cac32852b63..09d541a2b0a 100644 --- a/model-optimizer/extensions/ops/select.py +++ b/model-optimizer/extensions/ops/select.py @@ -49,14 +49,14 @@ class Select(Op): node.out_port(0).data.set_shape(broadcast_shape(a_shape, b_shape)) # Case with unknown condition if condition_value is not None: + output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1]) if condition_value.size != 1: - output_value = np.where(condition_value, resulting_tensors[0], resulting_tensors[1]) if np.any(output_value == None): # If any element of output value is None that means that we use the value from 'then' or 'else' tensor # which is not defined, this means that we cannot perform value propagation. output_value = None else: - output_value = resulting_tensors[not np.bool(condition_value.item(0))] + output_value = np.array(output_value, dtype=resulting_tensors[not np.bool(condition_value.item(0))].dtype) if output_value is not None: node.out_port(0).data.set_value(np.array(output_value)) diff --git a/model-optimizer/extensions/ops/select_test.py b/model-optimizer/extensions/ops/select_test.py index e0e1686a30b..2f504b79d4e 100644 --- a/model-optimizer/extensions/ops/select_test.py +++ b/model-optimizer/extensions/ops/select_test.py @@ -151,31 +151,52 @@ class TestSelect(unittest.TestCase): self.assertTrue(flag, resp) @generate(*[ - ([5, 6], [5, 6], [5, 6], np.ones(np.array([5, 6]), dtype=np.float), - np.zeros([5, 6], dtype=np.float), np.ones([5, 6], dtype=np.float), - np.ones([5, 6], dtype=np.float)), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float), - np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float), - np.ones([15, 3, 5], dtype=np.float)), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float), - None, np.ones([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float)), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.ones([15, 3, 5], dtype=np.float), - np.ones([15, 3, 5], dtype=np.float), None, None), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float), - None, np.ones([15, 3, 5], dtype=np.float), None), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.zeros([15, 3, 5], dtype=np.float), - np.ones([15, 3, 5], dtype=np.float), None, np.ones([15, 3, 5], dtype=np.float)), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([True], np.bool), - np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float), - np.ones([15, 3, 5], dtype=np.float)), - ([15, 3, 5], [15, 1, 5], [15, 3, 5], np.array([False], np.bool), - np.zeros([15, 3, 5], dtype=np.float), np.ones([15, 3, 5], dtype=np.float), - np.zeros([15, 3, 5], dtype=np.float)), + ([5, 6], [5, 6], [5, 6], [5, 6], lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float)), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: None), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float), + lambda x: None, lambda x: np.ones(x, dtype=np.float), lambda x: None), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.zeros(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float), lambda x: None, lambda x: np.ones(x, dtype=np.float)), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([True], np.bool), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), + ([15, 3, 5], [15, 3, 5], [15, 1, 5], [15, 3, 5], lambda x: np.array([False], np.bool), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float)), + ([2, 3, 4, 5], [2, 3, 4, 5], [], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), + ([2, 3, 4, 5], [2, 3, 4, 5], [5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), + ([2, 3, 1, 1], [2, 1, 1, 5], [2, 3, 4, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), + ([2, 3, 4, 1], [2, 1, 1, 5], [1, 3, 1, 5], [2, 3, 4, 5], lambda x: np.ones(x, dtype=np.float), + lambda x: np.zeros(x, dtype=np.float), lambda x: np.ones(x, dtype=np.float), + lambda x: np.ones(x, dtype=np.float)), ]) - def test_select_infer_condition_with_value(self, else_data_shape, than_data_shape, select_output_shape, + def test_select_infer_condition_with_value(self, condition_shape, else_data_shape, than_data_shape, select_output_shape, condition_value, else_value, than_value, output_value): + """ + Unit tests generator can sporadic throw exception if we try + to run generator with call numpy array generation functions. + So we need to use lambda function for escape the problem. + """ + condition_value = condition_value(condition_shape) + else_value = else_value(else_data_shape) + than_value = than_value(than_data_shape) + output_value = output_value(select_output_shape) graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges, - update_nodes_attributes=[('condition_data', {'shape': np.array(select_output_shape), + update_nodes_attributes=[('condition_data', {'shape': np.array(condition_shape), 'value': condition_value}), ('else_data', {'shape': np.array(else_data_shape), 'value': else_value}), @@ -188,7 +209,7 @@ class TestSelect(unittest.TestCase): graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges, update_nodes_attributes=[ - ('condition_data', {'shape': np.array(else_data_shape), + ('condition_data', {'shape': np.array(condition_shape), 'value': condition_value}), ('else_data', {'shape': np.array(else_data_shape), 'value': else_value}),