From fd754ab91766f779b0b698543b31405cc20a7c98 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Thu, 20 Jan 2022 22:47:40 +0300 Subject: [PATCH] SimplifySecondInputOfReshapeFix (#9809) --- .../simplify_shape_of_sub_graph.cpp | 1 + .../simplify_second_input_of_reshape_test.cpp | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp index e3906ec656a..4caa1311136 100644 --- a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp @@ -267,6 +267,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { if (const auto gather = as_type_ptr(input.get_node_shared_ptr())) { auto indices_constant = as_type_ptr(gather->get_input_node_shared_ptr(1)); if (!indices_constant || !check_shape_of_gather(gather)) { + gather_dims_expected_location++; continue; } diff --git a/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp b/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp index c1895819e89..3713352ef11 100644 --- a/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp +++ b/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp @@ -581,3 +581,37 @@ TEST(TransformationTests, SimplifySecondInputOfReshapeTest16) { auto res = compare_functions(f, f_ref, true); ASSERT_TRUE(res.first) << res.second; } + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest17) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{-1, 256, -1}; + { + auto data_1 = std::make_shared(element::f32, data_shape); + auto data_2 = std::make_shared(element::f32, data_shape); + + auto shape_of_1 = std::make_shared(data_1); + auto gather_op_1 = gather(shape_of_1, std::vector{0}); + + auto constant_1 = opset7::Constant::create(element::i64, Shape{1}, {4}); + auto constant_2 = opset7::Constant::create(element::i64, Shape{1}, {64}); + + auto shape_of_2 = std::make_shared(data_2); + auto gather_op_2 = gather(shape_of_2, std::vector{2}); + auto concat = std::make_shared(OutputVector{gather_op_1, constant_1, constant_2, gather_op_2}, 0); + + auto reshape = std::make_shared(data_2, concat, true); + f = std::make_shared(NodeVector{reshape}, ParameterVector{data_1, data_2}); + f_ref = ngraph::clone_function(*f); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({-1, 4, 64, -1})); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +}