Update ConstantFolding transformation to support Gather with dynamic input (#16973)

* ConstFold Gather op in case of dynamic dims in data input

* Update ConstantFolding transformation to support Gather with dynamic input; add test

* always mark ShapeOf nodes as can_be_folded

* add additional checks for fused_names in the gather test

---------

Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
This commit is contained in:
Ivan Tikhonov
2023-04-26 13:22:47 +04:00
committed by GitHub
parent ce5f65af14
commit 95ca54d0ab
2 changed files with 49 additions and 1 deletions

View File

@@ -130,7 +130,7 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_
// propagation because we can't detect borders of shape_of sub-graphs, so we propagate can_be_folded
// attribute through all nodes including nodes on data path. So to limit the spread of attribute to other
// shape-of sub-graphs we do not propagate it through ShapeOf nodes.
can_be_folded = input_values.begin()->get_partial_shape().is_static();
can_be_folded = true;
} else if (op::util::is_parameter(node) || op::util::is_output(node) || op::util::is_sink(node) ||
is_type<op::util::ReadValueBase>(node)) {
can_be_folded = false;

View File

@@ -12,6 +12,7 @@
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/pass/manager.hpp"
#include "openvino/opsets/opset11.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
@@ -3584,3 +3585,50 @@ TEST(constant_folding, evaluate_on_tensor_vector) {
ASSERT_EQ(data_shape, result_node->get_output_shape(0));
ASSERT_EQ(add_expected, result_node->cast_vector<int>());
}
TEST(constant_folding, gather_with_dynamic_shapes_in_data_input) {
auto in_0 = std::make_shared<ov::opset11::Parameter>(ov::element::i64, ov::PartialShape{30});
// dynamic input to Gather
auto in_1 = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::PartialShape{-1, 2});
in_1->set_friendly_name("in_1");
auto shape_of = std::make_shared<ov::opset11::ShapeOf>(in_1);
shape_of->set_friendly_name("shape_of");
auto indices = std::make_shared<ov::opset11::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int>{1});
indices->set_friendly_name("indices");
auto axis = std::make_shared<ov::opset11::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int>{0});
axis->set_friendly_name("axis");
auto gather = std::make_shared<ov::opset11::Gather>(shape_of, indices, axis);
gather->set_friendly_name("test");
auto in_2 = std::make_shared<ov::opset11::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int>{10});
in_2->set_friendly_name("in_2");
auto in_3 = std::make_shared<ov::opset11::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int>{1});
in_3->set_friendly_name("in_3");
auto strided_slice = std::make_shared<ov::opset11::StridedSlice>(in_0,
gather,
in_2,
in_3,
std::vector<int64_t>{0, 0},
std::vector<int64_t>{0, 0},
std::vector<int64_t>{0, 0},
std::vector<int64_t>{0, 1});
strided_slice->set_friendly_name("strided_slice");
auto res = std::make_shared<ov::opset11::Result>(strided_slice);
res->set_friendly_name("result");
auto model = std::make_shared<ov::Model>(ov::ResultVector{res}, ov::ParameterVector{in_0, in_1});
run_constant_folding(model);
ASSERT_EQ(count_ops_of_type<ov::opset11::Gather>(model), 0);
ASSERT_EQ(count_ops_of_type<ov::opset11::StridedSlice>(model), 1);
auto new_const = dynamic_pointer_cast<ov::opset11::Constant>(strided_slice->input_value(1).get_node_shared_ptr());
EXPECT_NE(new_const, nullptr);
check_names(new_const, {"shape_of", "indices", "axis", "test"});
// check that we are not copying unnecessary values
check_names(strided_slice, {"strided_slice"}, "strided_slice");
check_names(res, {"result"}, "result");
}