diff --git a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp index 14dede1e6f8..27d5d5233bb 100644 --- a/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp @@ -74,20 +74,31 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() { while (inputs.size() > i + 1) { auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr(); if (curr->get_type_info() != next->get_type_info() || - (!ov::is_type(curr) && !ov::is_type(curr)) || + (!ov::is_type(curr) && !ov::is_type(curr) && !ov::is_type(curr)) || (curr->input_value(0) != next->input_value(0))) { ++i; continue; } // curr and next are the same type of gather which takes data from the same source - bool is_opset1 = ov::is_type(curr); auto joint_indices = ngraph::op::util::make_try_fold(OutputVector{curr->input_value(1), next->input_value(1)}, 0); std::shared_ptr new_gather; - if (is_opset1) + if (ov::is_type(curr)) { new_gather = register_new_node( - curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0)); - else + curr->input_value(0), + joint_indices->output(0), + ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0)); + } else if (ov::is_type(curr)) { new_gather = register_new_node( - curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0)); + curr->input_value(0), + joint_indices->output(0), + ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0)); + } else if (ov::is_type(curr)) { + new_gather = register_new_node( + curr->input_value(0), + joint_indices->output(0), + ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0)); + } else { + OPENVINO_UNREACHABLE("Unexpected Gather version"); + } new_ops.push_back(joint_indices); new_ops.push_back(new_gather); inputs.erase(inputs.begin() + i); @@ -239,8 +250,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() { auto check_shape_of_gather = [&](const std::shared_ptr& gather) { auto shape_of = gather->get_input_node_shared_ptr(0); - if ((!is_type(shape_of) && !is_type(shape_of)) || - (shape_of->get_output_target_inputs(0).size() > 1)) { + if (!is_type(shape_of) && !is_type(shape_of)) { return false; } return shape_of->input_value(0) == data; diff --git a/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp b/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp index 99017008459..4bc2faa700d 100644 --- a/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp +++ b/src/tests/functional/inference_engine/transformations/simplify_second_input_of_reshape_test.cpp @@ -547,3 +547,37 @@ TEST(TransformationTests, SimplifySecondInputOfReshapeTest15) { auto res = compare_functions(f, f_ref, true); ASSERT_TRUE(res.first) << res.second; } + +TEST(TransformationTests, SimplifySecondInputOfReshapeTest16) { + std::shared_ptr f(nullptr), f_ref(nullptr); + + PartialShape data_shape{ 1, 128, 12, 64 }; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_of = std::make_shared(data); + auto gather_op_1 = gather(shape_of, std::vector{0}); + auto gather_op_2 = gather(shape_of, std::vector{1}); + auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 }); + auto concat = std::make_shared(OutputVector{ gather_op_1, gather_op_2, constant }, 0); + + auto reshape = std::make_shared(data, concat, true); + f = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 })); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 }); + auto reshape = std::make_shared(data, reshape_pattern, true); + f_ref = std::make_shared(NodeVector{ reshape }, ParameterVector{ data }); + } + + auto res = compare_functions(f, f_ref, true); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/src/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp b/src/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp index 43491a78c40..1b365e2f6c3 100644 --- a/src/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp +++ b/src/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -20,7 +21,7 @@ using namespace testing; using namespace ngraph; -auto gather = [](const std::shared_ptr input, std::vector indices, bool scalar = false) -> Output { +auto gatherv7 = [](const std::shared_ptr input, std::vector indices, bool scalar = false) -> Output { std::shared_ptr indices_node; if (scalar) indices_node = opset7::Constant::create(element::i64, {}, indices); @@ -30,18 +31,29 @@ auto gather = [](const std::shared_ptr input, std::vector indices input, indices_node, opset7::Constant::create(element::i64, {}, {0})); }; -TEST_F(TransformationTestsF, ShapeSubGraphTest) { +auto gatherv8 = [](const std::shared_ptr input, std::vector indices, bool scalar = false) -> Output { + std::shared_ptr indices_node; + if (scalar) + indices_node = opset7::Constant::create(element::i64, {}, indices); + else + indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices); + return std::make_shared(input, + indices_node, + opset7::Constant::create(element::i64, {}, {0})); +}; + +TEST_F(TransformationTestsF, ShapeSubGraphTestGatherv7) { Shape data_shape{1, 2, 3, 4}; { auto data = std::make_shared(element::f32, data_shape); auto shape_op_1 = std::make_shared(data); - auto gather_1 = gather(shape_op_1, {1}, true); + auto gather_1 = gatherv7(shape_op_1, {1}, true); auto unsqueeze_1 = std::make_shared( gather_1, opset7::Constant::create(element::i64, {1}, {0})); auto shape_op_2 = std::make_shared(data); - auto gather_2 = gather(shape_op_2, {2}, true); + auto gather_2 = gatherv7(shape_op_2, {2}, true); auto unsqueeze_2 = std::make_shared( gather_2, opset7::Constant::create(element::i64, {1}, {0})); @@ -58,7 +70,7 @@ TEST_F(TransformationTestsF, ShapeSubGraphTest) { auto data = std::make_shared(element::f32, data_shape); auto shape_op_1 = std::make_shared(data); - auto gather_1 = gather(shape_op_1, {1, 2}); + auto gather_1 = gatherv7(shape_op_1, {1, 2}); auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2}); auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2}); @@ -70,18 +82,58 @@ TEST_F(TransformationTestsF, ShapeSubGraphTest) { } } -TEST_F(TransformationTestsF, ShapeNopSubGraphTest) { +TEST_F(TransformationTestsF, ShapeSubGraphTestGatherv8) { + Shape data_shape{1, 2, 3, 4}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_op_1 = std::make_shared(data); + auto gather_1 = gatherv8(shape_op_1, {1}, true); + auto unsqueeze_1 = + std::make_shared(gather_1, opset7::Constant::create(element::i64, {1}, {0})); + + auto shape_op_2 = std::make_shared(data); + auto gather_2 = gatherv8(shape_op_2, {2}, true); + auto unsqueeze_2 = + std::make_shared(gather_2, opset7::Constant::create(element::i64, {1}, {0})); + + auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2}); + auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2}); + + auto concat = std::make_shared(OutputVector{unsqueeze_1, unsqueeze_2, const_1, const_2}, 0); + + auto reshape = std::make_shared(data, concat, false); + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_op_1 = std::make_shared(data); + auto gather_1 = gatherv8(shape_op_1, {1, 2}); + + auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2}); + auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2}); + + auto concat = std::make_shared(OutputVector{gather_1, const_1, const_2}, 0); + + auto reshape = std::make_shared(data, concat, false); + function_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, ShapeNopSubGraphTestGatherv7) { PartialShape data_shape{-1, -1}; { auto data = std::make_shared(element::f32, data_shape); auto shape_op_1 = std::make_shared(data); - auto gather_1 = gather(shape_op_1, {0}, true); + auto gather_1 = gatherv7(shape_op_1, {0}, true); auto unsqueeze_1 = std::make_shared( gather_1, opset7::Constant::create(element::i64, {1}, {0})); auto shape_op_2 = std::make_shared(data); - auto gather_2 = gather(shape_op_2, {1}, true); + auto gather_2 = gatherv7(shape_op_2, {1}, true); auto unsqueeze_2 = std::make_shared( gather_2, opset7::Constant::create(element::i64, {1}, {0})); @@ -98,3 +150,51 @@ TEST_F(TransformationTestsF, ShapeNopSubGraphTest) { function_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); } } + +TEST_F(TransformationTestsF, ShapeNopSubGraphTestGatherv8) { + PartialShape data_shape{-1, -1}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_op_1 = std::make_shared(data); + auto gather_1 = gatherv8(shape_op_1, {0}, true); + auto unsqueeze_1 = + std::make_shared(gather_1, opset7::Constant::create(element::i64, {1}, {0})); + + auto shape_op_2 = std::make_shared(data); + auto gather_2 = gatherv8(shape_op_2, {1}, true); + auto unsqueeze_2 = + std::make_shared(gather_2, opset7::Constant::create(element::i64, {1}, {0})); + + auto concat = std::make_shared(OutputVector{unsqueeze_1, unsqueeze_2}, 0); + + auto reshape = std::make_shared(data, concat, false); + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, data_shape); + auto shape_op_1 = std::make_shared(data); + auto reshape = std::make_shared(data, shape_op_1, false); + function_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, GroupedGatherEliminationNegative) { + PartialShape data_shape{2, 128}; + { + auto data = std::make_shared(element::f32, data_shape); + + auto shape_op = std::make_shared(data); + auto gather = gatherv8(shape_op, {1}, true); + auto unsqueeze = std::make_shared(gather, opset7::Constant::create(element::i64, {1}, {0})); + + auto constant_1 = ngraph::opset7::Constant::create(element::i64, {1}, {0}); + auto constant_2 = ngraph::opset7::Constant::create(element::i64, {1}, {1}); + auto concat = std::make_shared(OutputVector{constant_1, constant_2, unsqueeze}, 0); + + auto reshape = std::make_shared(data, concat, true); + function = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + manager.register_pass(); + } +}