fix bug on conversion of gather to sequeeze (#19094)
This commit is contained in:
@@ -47,9 +47,11 @@ static bool simplify_gather(shared_ptr<Node> node) {
|
||||
if (!constant_indices)
|
||||
return false;
|
||||
// case_3: if input_shape is (1,3,5,5) and axis = 0, indices = 0, then gather is just a Squeeze
|
||||
const auto constant_indices_size = constant_indices->get_output_shape(0).size();
|
||||
const auto const_indices = constant_indices->cast_vector<int64_t>();
|
||||
if (data.get_shape()[axis] == 1 && const_indices.size() == 1 && const_indices[0] == 0) {
|
||||
auto squeeze = std::make_shared<opset8::Squeeze>(gather->input_value(0), gather->input_value(2));
|
||||
if (data.get_shape()[axis] == 1 && (constant_indices_size == 0 || constant_indices_size == 1) &&
|
||||
const_indices[0] == 0) {
|
||||
auto squeeze = std::make_shared<opset3::Squeeze>(gather->input_value(0), gather->input_value(2));
|
||||
squeeze->set_friendly_name(gather->get_friendly_name());
|
||||
ov::copy_runtime_info(gather, squeeze);
|
||||
ov::replace_node(gather, squeeze);
|
||||
|
||||
@@ -1339,3 +1339,31 @@ TEST(nop_elimination, gather_to_squeeze) {
|
||||
run_and_check(func_axis_2);
|
||||
run_and_check(func_axis_3);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, not_gather_to_squeeze_with_vector_indices) {
|
||||
auto generate_func = [](int64_t gather_axis) {
|
||||
ov::Shape shape{3, 3, 4, 4};
|
||||
shape[gather_axis] = 1;
|
||||
auto arg = std::make_shared<op::Parameter>(element::f32, shape);
|
||||
auto indices = op::Constant::create(element::i64, Shape{1, 1}, vector<int64_t>{0});
|
||||
auto axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{gather_axis});
|
||||
auto gather = std::make_shared<op::v8::Gather>(arg, indices, axis);
|
||||
return std::make_shared<ov::Model>(NodeVector{gather}, ParameterVector{arg});
|
||||
};
|
||||
|
||||
auto func_axis_0 = generate_func(0);
|
||||
auto func_axis_1 = generate_func(1);
|
||||
auto func_axis_2 = generate_func(2);
|
||||
auto func_axis_3 = generate_func(3);
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<ov::pass::NopElimination>();
|
||||
auto run_and_check = [&](std::shared_ptr<ov::Model>& func) {
|
||||
pass_manager.run_passes(func);
|
||||
EXPECT_EQ(count_ops_of_type<op::v8::Gather>(func), 1);
|
||||
EXPECT_EQ(count_ops_of_type<op::v0::Squeeze>(func), 0);
|
||||
};
|
||||
run_and_check(func_axis_0);
|
||||
run_and_check(func_axis_1);
|
||||
run_and_check(func_axis_2);
|
||||
run_and_check(func_axis_3);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user