From 5ded6fb6990dd893d3ef23f921cda5b9a0e3d6f7 Mon Sep 17 00:00:00 2001 From: Shuangji Yang Date: Fri, 11 Aug 2023 02:25:59 +0800 Subject: [PATCH] fix bug on conversion of gather to sequeeze (#19094) --- .../common_optimizations/nop_elimination.cpp | 6 ++-- .../common_optimizations/nop_elimination.cpp | 28 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 52acd6c4628..63de5e9e702 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -47,9 +47,11 @@ static bool simplify_gather(shared_ptr node) { if (!constant_indices) return false; // case_3: if input_shape is (1,3,5,5) and axis = 0, indices = 0, then gather is just a Squeeze + const auto constant_indices_size = constant_indices->get_output_shape(0).size(); const auto const_indices = constant_indices->cast_vector(); - if (data.get_shape()[axis] == 1 && const_indices.size() == 1 && const_indices[0] == 0) { - auto squeeze = std::make_shared(gather->input_value(0), gather->input_value(2)); + if (data.get_shape()[axis] == 1 && (constant_indices_size == 0 || constant_indices_size == 1) && + const_indices[0] == 0) { + auto squeeze = std::make_shared(gather->input_value(0), gather->input_value(2)); squeeze->set_friendly_name(gather->get_friendly_name()); ov::copy_runtime_info(gather, squeeze); ov::replace_node(gather, squeeze); diff --git a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp index 0932709116a..3d30bd56ffe 100644 --- a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp @@ -1339,3 +1339,31 @@ TEST(nop_elimination, gather_to_squeeze) { run_and_check(func_axis_2); run_and_check(func_axis_3); } + +TEST(nop_elimination, not_gather_to_squeeze_with_vector_indices) { + auto generate_func = [](int64_t gather_axis) { + ov::Shape shape{3, 3, 4, 4}; + shape[gather_axis] = 1; + auto arg = std::make_shared(element::f32, shape); + auto indices = op::Constant::create(element::i64, Shape{1, 1}, vector{0}); + auto axis = op::Constant::create(element::i64, Shape{}, vector{gather_axis}); + auto gather = std::make_shared(arg, indices, axis); + return std::make_shared(NodeVector{gather}, ParameterVector{arg}); + }; + + auto func_axis_0 = generate_func(0); + auto func_axis_1 = generate_func(1); + auto func_axis_2 = generate_func(2); + auto func_axis_3 = generate_func(3); + pass::Manager pass_manager; + pass_manager.register_pass(); + auto run_and_check = [&](std::shared_ptr& func) { + pass_manager.run_passes(func); + EXPECT_EQ(count_ops_of_type(func), 1); + EXPECT_EQ(count_ops_of_type(func), 0); + }; + run_and_check(func_axis_0); + run_and_check(func_axis_1); + run_and_check(func_axis_2); + run_and_check(func_axis_3); +}