Optimize Concat operation (#1812)

* 1d case optimization

* code refactor

* concat optimization

* removed using template for concat

* unit tests to concat constant folding

* synchro with current master
This commit is contained in:
Mateusz Bencer
2020-08-18 17:28:57 +02:00
committed by GitHub
parent 8c5262f864
commit a63c8d9537
4 changed files with 158 additions and 85 deletions

View File

@@ -1439,6 +1439,88 @@ TEST(constant_folding, const_concat)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat_3d_single_elem)
{
auto constant_1 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{1});
auto constant_2 = op::Constant::create(element::i32, Shape{1, 1, 1}, vector<int32_t>{2});
auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 0);
auto f = make_shared<Function>(concat, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_shape(0), (Shape{2, 1, 1}));
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 2};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat_axis_2)
{
auto constant_1 =
op::Constant::create(element::i32, Shape{3, 1, 2}, vector<int32_t>{1, 2, 3, 4, 5, 6});
auto constant_2 = op::Constant::create(
element::i32, Shape{3, 1, 4}, vector<int32_t>{7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2}, 2);
auto f = make_shared<Function>(concat, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_shape(0), (Shape{3, 1, 6}));
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 2, 7, 8, 9, 10, 3, 4, 11, 12, 13, 14, 5, 6, 15, 16, 17, 18};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat_axis_1_bool_type)
{
auto constant_1 =
op::Constant::create(element::boolean, Shape{1, 1, 2}, vector<int32_t>{true, true});
auto constant_2 = op::Constant::create(
element::boolean, Shape{1, 2, 2}, vector<char>{true, false, true, false});
auto constant_3 = op::Constant::create(
element::boolean, Shape{1, 3, 2}, vector<char>{true, false, true, false, true, false});
auto concat = make_shared<op::Concat>(NodeVector{constant_1, constant_2, constant_3}, 1);
auto f = make_shared<Function>(concat, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_output_shape(0), (Shape{1, 6, 2}));
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{
true, true, true, false, true, false, true, false, true, false, true, false};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_not)
{
auto constant =