Optimize Concat operation (#1812)
* 1d case optimization * code refactor * concat optimization * removed using template for concat * unit tests to concat constant folding * synchro with current master
This commit is contained in:
@@ -1439,6 +1439,88 @@ TEST(constant_folding, const_concat)
|
||||
ASSERT_EQ(values_expected, values_out);
|
||||
}
|
||||
|
||||
TEST(constant_folding, const_concat_3d_single_elem)
|
||||
{
|
||||
auto constant_1 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{1});
|
||||
auto constant_2 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{2});
|
||||
auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 0);
|
||||
auto f = make_shared<Function>(concat, ParameterVector{});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
||||
|
||||
auto new_const =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
|
||||
ASSERT_TRUE(new_const);
|
||||
ASSERT_EQ(new_const->get_output_shape(0), (Shape{2, 1, 1}));
|
||||
|
||||
auto values_out = new_const->get_vector<int32_t>();
|
||||
vector<int32_t> values_expected{1, 2};
|
||||
ASSERT_EQ(values_expected, values_out);
|
||||
}
|
||||
|
||||
TEST(constant_folding, const_concat_axis_2)
|
||||
{
|
||||
auto constant_1 =
|
||||
op::Constant::create(element::i32, Shape{3, 1, 2}, vector<int32_t>{1, 2, 3, 4, 5, 6});
|
||||
auto constant_2 = op::Constant::create(
|
||||
element::i32, Shape{3, 1, 4}, vector<int32_t>{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
|
||||
auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 2);
|
||||
auto f = make_shared<Function>(concat, ParameterVector{});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
||||
|
||||
auto new_const =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
|
||||
ASSERT_TRUE(new_const);
|
||||
ASSERT_EQ(new_const->get_output_shape(0), (Shape{3, 1, 6}));
|
||||
|
||||
auto values_out = new_const->get_vector<int32_t>();
|
||||
vector<int32_t> values_expected{1, 2, 7, 8, 9, 10, 3, 4, 11, 12, 13, 14, 5, 6, 15, 16, 17, 18};
|
||||
ASSERT_EQ(values_expected, values_out);
|
||||
}
|
||||
|
||||
TEST(constant_folding, const_concat_axis_1_bool_type)
|
||||
{
|
||||
auto constant_1 =
|
||||
op::Constant::create(element::boolean, Shape{1, 1, 2}, vector<int32_t>{true, true});
|
||||
auto constant_2 = op::Constant::create(
|
||||
element::boolean, Shape{1, 2, 2}, vector<char>{true, false, true, false});
|
||||
auto constant_3 = op::Constant::create(
|
||||
element::boolean, Shape{1, 3, 2}, vector<char>{true, false, true, false, true, false});
|
||||
auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2, constant_3}, 1);
|
||||
auto f = make_shared<Function>(concat, ParameterVector{});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
||||
|
||||
auto new_const =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
|
||||
ASSERT_TRUE(new_const);
|
||||
ASSERT_EQ(new_const->get_output_shape(0), (Shape{1, 6, 2}));
|
||||
|
||||
auto values_out = new_const->get_vector<char>();
|
||||
vector<char> values_expected{
|
||||
true, true, true, false, true, false, true, false, true, false, true, false};
|
||||
ASSERT_EQ(values_expected, values_out);
|
||||
}
|
||||
|
||||
TEST(constant_folding, const_not)
|
||||
{
|
||||
auto constant =
|
||||
|
||||
Reference in New Issue
Block a user