fix bug on conversion of gather to sequeeze (#19094)

This commit is contained in:
Shuangji Yang
2023-08-11 02:25:59 +08:00
committed by GitHub
parent eabf199c3a
commit 5ded6fb699
2 changed files with 32 additions and 2 deletions

View File

@@ -47,9 +47,11 @@ static bool simplify_gather(shared_ptr<Node> node) {
if (!constant_indices)
return false;
// case_3: if input_shape is (1,3,5,5) and axis = 0, indices = 0, then gather is just a Squeeze
const auto constant_indices_size = constant_indices->get_output_shape(0).size();
const auto const_indices = constant_indices->cast_vector<int64_t>();
if (data.get_shape()[axis] == 1 && const_indices.size() == 1 && const_indices[0] == 0) {
auto squeeze = std::make_shared<opset8::Squeeze>(gather->input_value(0), gather->input_value(2));
if (data.get_shape()[axis] == 1 && (constant_indices_size == 0 || constant_indices_size == 1) &&
const_indices[0] == 0) {
auto squeeze = std::make_shared<opset3::Squeeze>(gather->input_value(0), gather->input_value(2));
squeeze->set_friendly_name(gather->get_friendly_name());
ov::copy_runtime_info(gather, squeeze);
ov::replace_node(gather, squeeze);

View File

@@ -1339,3 +1339,31 @@ TEST(nop_elimination, gather_to_squeeze) {
run_and_check(func_axis_2);
run_and_check(func_axis_3);
}
TEST(nop_elimination, not_gather_to_squeeze_with_vector_indices) {
auto generate_func = [](int64_t gather_axis) {
ov::Shape shape{3, 3, 4, 4};
shape[gather_axis] = 1;
auto arg = std::make_shared<op::Parameter>(element::f32, shape);
auto indices = op::Constant::create(element::i64, Shape{1, 1}, vector<int64_t>{0});
auto axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{gather_axis});
auto gather = std::make_shared<op::v8::Gather>(arg, indices, axis);
return std::make_shared<ov::Model>(NodeVector{gather}, ParameterVector{arg});
};
auto func_axis_0 = generate_func(0);
auto func_axis_1 = generate_func(1);
auto func_axis_2 = generate_func(2);
auto func_axis_3 = generate_func(3);
pass::Manager pass_manager;
pass_manager.register_pass<ov::pass::NopElimination>();
auto run_and_check = [&](std::shared_ptr<ov::Model>& func) {
pass_manager.run_passes(func);
EXPECT_EQ(count_ops_of_type<op::v8::Gather>(func), 1);
EXPECT_EQ(count_ops_of_type<op::v0::Squeeze>(func), 0);
};
run_and_check(func_axis_0);
run_and_check(func_axis_1);
run_and_check(func_axis_2);
run_and_check(func_axis_3);
}