diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp index 2c7c7e5cd0d..6ede4b4e329 100644 --- a/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp +++ b/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp @@ -22,6 +22,7 @@ class TRANSFORMATIONS_API SharedShapeOf; class TRANSFORMATIONS_API GroupedGatherElimination; class TRANSFORMATIONS_API GatherNopElimination; class TRANSFORMATIONS_API SimplifyGatherShapeOf; +class TRANSFORMATIONS_API SimplifySecondInputOfReshape; } // namespace pass } // namespace ngraph @@ -82,3 +83,14 @@ public: NGRAPH_RTTI_DECLARATION; SimplifyGatherShapeOf(); }; + +/** + * @ingroup ie_transformation_common_api + * @brief SimplifySecondInputOfReshape optimizes `shapeof->gather` into zero values for + * reshape pattern values if possible. + */ +class ngraph::pass::SimplifySecondInputOfReshape : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + SimplifySecondInputOfReshape(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp index 7e11215d9d0..b7affe594a9 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -191,6 +192,102 @@ ngraph::pass::SimplifyGatherShapeOf::SimplifyGatherShapeOf() { this->register_matcher(m, callback); } +NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifySecondInputOfReshape, "SimplifySecondInputOfReshape", 0); + +ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { + MATCHER_SCOPE(SimplifySecondInputOfReshape); + const auto input = pattern::any_input(); + auto has_static_1d_shape = [](const Output& output) { + return pattern::has_static_shape()(output) && pattern::rank_equals(1)(output); + }; + const auto concat = pattern::wrap_type(has_static_1d_shape); + const auto reshape_pattern = pattern::wrap_type({ input, concat }); + + ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto node = m.get_match_root(); + const auto reshape = as_type_ptr(node); + if (!reshape || reshape->get_special_zero() == false) { + return false; + } + + const auto concat = as_type_ptr(reshape->get_input_node_shared_ptr(1)); + if (!concat) + return false; + + const auto concat_axis = concat->get_axis(); + OPENVINO_ASSERT(concat_axis == 0, "axis is not valid for matched Concat with 1D output"); + + auto data = m.get_pattern_value_map().at(input); + if (is_type(data.get_node_shared_ptr()) || + ngraph::op::is_unary_elementwise_arithmetic(data.get_node_shared_ptr())) { + data = data.get_node_shared_ptr()->input_value(0); + } + + auto check_shape_of_gather = [&](const std::shared_ptr& gather) { + auto shape_of = gather->get_input_node_shared_ptr(0); + if ((!is_type(shape_of) && !is_type(shape_of)) || + (shape_of->get_output_target_inputs(0).size() > 1)) { + return false; + } + return shape_of->input_value(0) == data; + }; + + const auto concat_inputs = concat->input_values(); + OutputVector new_concat_inputs = concat_inputs; + std::int64_t gather_dims_expected_location = 0; + bool gather_folded = false; + + // We need this check to avoid sequences shapeOf -> gather -> concat + // that change the arrangement of dimensions in the reshape pattern + for (auto& input : new_concat_inputs) { + if (const auto gather = as_type_ptr(input.get_node_shared_ptr())) { + auto indices_constant = as_type_ptr(gather->get_input_node_shared_ptr(1)); + if (!indices_constant || !check_shape_of_gather(gather)) { + continue; + } + + bool gather_can_be_fused = true; + const auto indices = indices_constant->cast_vector(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] != gather_dims_expected_location) { + gather_can_be_fused = false; + } + gather_dims_expected_location++; + } + + if (gather_can_be_fused) { + const size_t num_of_unchanged_dimensions = indices.size(); + const auto subgraph_et = gather->get_input_element_type(0); + input = opset8::Constant::create(subgraph_et, Shape{ num_of_unchanged_dimensions }, { 0 }); + gather_folded = true; + } + } else { + const auto concat_input_shape = input.get_shape(); + OPENVINO_ASSERT(concat_input_shape.size() == 1, "concat input rank is not valid for matched Concat with 1D output"); + gather_dims_expected_location += concat_input_shape[0]; + } + } + + if (!gather_folded) { + return false; + } + + const auto new_concat = op::util::make_try_fold(new_concat_inputs, concat_axis); + new_concat->set_friendly_name(concat->get_friendly_name()); + copy_runtime_info(concat, new_concat); + + const auto new_reshape = reshape->clone_with_new_inputs({ reshape->input_value(0), new_concat }); + new_reshape->set_friendly_name(reshape->get_friendly_name()); + + copy_runtime_info(reshape, new_reshape); + replace_node(reshape, new_reshape); + return true; + }; + + auto m = std::make_shared(reshape_pattern, matcher_name); + this->register_matcher(m, callback); +} + NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0); bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr f) { @@ -201,6 +298,7 @@ bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.run_passes(f); return false; } diff --git a/inference-engine/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp new file mode 100644 index 00000000000..452c9552eb5 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp @@ -0,0 +1,548 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include + +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + + +using namespace testing; +using namespace ngraph; + +auto gather = [](const std::shared_ptr input, std::vector indices) -> Output { + std::shared_ptr indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices); + std::shared_ptr axis_node = opset7::Constant::create(element::i64, {}, { 0 }); + return std::make_shared(input, indices_node, axis_node); +}; + +auto fake_quantize = [](const std::shared_ptr input) -> Output { + auto il = opset7::Constant::create(element::f32, Shape{}, { 0.f }); + auto ih = opset7::Constant::create(element::f32, Shape{}, { 25.5f }); + auto ol = opset7::Constant::create(element::f32, Shape{}, { 0.f }); + auto oh = opset7::Constant::create(element::f32, Shape{}, { 25.5f }); + return std::make_shared(input, il, ih, ol, oh, 256); +}; + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest1) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{1, 128, 12, 64}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{1}, {768}); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest2) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + auto fq = fake_quantize(data); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(fq, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto fq = fake_quantize(data); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 }); + auto reshape = std::make_shared(fq, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest3) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 768 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 }); + auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant_1, constant_2 }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest4) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 768 }; + { + auto data = std::make_shared(element::f32, data_shape); + auto fq = fake_quantize(data); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 }); + auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant_1, constant_2 }, 0); + + auto reshape = std::make_shared(fq, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto fq = fake_quantize(data); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 }); + auto reshape = std::make_shared(fq, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest5) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape = PartialShape::dynamic(3); + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape::dynamic(3)); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, -1 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest6) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape = PartialShape::dynamic(); + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape::dynamic(3)); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, -1 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest7) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{2, 3}); + auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 }); + auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 2 }); + auto concat = std::make_shared(OutputVector{ constant_1, constant_2, gather_op }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 64, 2, 12, 64 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 64, 2, 0, 0 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest8) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{2}); + auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 }); + auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 2 }); + auto constant_3 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 }); + auto concat = std::make_shared(OutputVector{ constant_1, constant_2, gather_op, constant_3 }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 64, 2, 12, 64 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 64, 2, 0, 64 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest9) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 2}); + auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + f_ref = f; + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 12, 8192 })); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest10) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of_1 = std::make_shared(data); + auto shape_of_2 = std::make_shared(data); + auto gather_op_1 = gather(shape_of_1, std::vector{0, 1}); + auto gather_op_2 = gather(shape_of_2, std::vector{3}); + auto gather_op_3 = gather(shape_of_2, std::vector{2}); + auto concat = std::make_shared(OutputVector{ gather_op_1, gather_op_2, gather_op_3 }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 64, 12 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto constant = opset7::Constant::create(element::i64, Shape{ 2 }, { 0, 0 }); + auto gather_op_2 = gather(shape_of, std::vector{3}); + auto gather_op_3 = gather(shape_of, std::vector{2}); + auto concat = std::make_shared(OutputVector{ constant, gather_op_2, gather_op_3 }, 0); + + auto reshape = std::make_shared(data, concat, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest11) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of_1 = std::make_shared(data); + auto shape_of_2 = std::make_shared(data); + auto concat_input_0 = gather(shape_of_1, std::vector{0}); + auto concat_input_1 = ngraph::opset7::Constant::create(ngraph::element::i64, {1}, { 64 }); + auto concat_input_2 = gather(shape_of_2, std::vector{2}); + auto concat_input_3 = ngraph::opset7::Constant::create(ngraph::element::i64, {1}, { 128 }); + auto concat = std::make_shared(OutputVector{ concat_input_0, concat_input_1, concat_input_2, concat_input_3 }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 64, 12, 128 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto constant = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 64, 0, 128 }); + auto reshape = std::make_shared(data, constant, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest12) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 768 }; + { + auto data = std::make_shared(element::f32, data_shape); + auto gelu = std::make_shared(data); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 }); + auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant_1, constant_2 }, 0); + + auto reshape = std::make_shared(gelu, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto gelu = std::make_shared(data); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 }); + auto reshape = std::make_shared(gelu, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest13) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data, element::i32); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i32, Shape{ 1 }, { 768 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i32, Shape{ 3 }, { 0, 0, 768 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest14) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest15) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 768 }; + { + auto data = std::make_shared(element::f32, data_shape); + auto gelu = std::make_shared(data); + + auto shape_of = std::make_shared(data); + auto gather_op = gather(shape_of, std::vector{0, 1}); + auto constant = opset7::Constant::create(element::i64, Shape{ 2 }, { 12, 64 }); + auto concat = std::make_shared(OutputVector{ gather_op, constant }, 0); + + auto reshape = std::make_shared(gelu, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto gelu = std::make_shared(data); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 }); + auto reshape = std::make_shared(gelu, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +}