SimplifySecondInputOfReshapeFix (#9809)

This commit is contained in:
Vladislav Golubev 2022-01-20 22:47:40 +03:00 committed by GitHub
parent 29d73ce3c8
commit fd754ab917
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 0 deletions

View File

@ -267,6 +267,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
if (const auto gather = as_type_ptr<op::util::GatherBase>(input.get_node_shared_ptr())) {
auto indices_constant = as_type_ptr<opset8::Constant>(gather->get_input_node_shared_ptr(1));
if (!indices_constant || !check_shape_of_gather(gather)) {
gather_dims_expected_location++;
continue;
}

View File

@ -581,3 +581,37 @@ TEST(TransformationTests, SimplifySecondInputOfReshapeTest16) {
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest17) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{-1, 256, -1};
{
auto data_1 = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto data_2 = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data_1);
auto gather_op_1 = gather(shape_of_1, std::vector<int64_t>{0});
auto constant_1 = opset7::Constant::create(element::i64, Shape{1}, {4});
auto constant_2 = opset7::Constant::create(element::i64, Shape{1}, {64});
auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data_2);
auto gather_op_2 = gather(shape_of_2, std::vector<int64_t>{2});
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op_1, constant_1, constant_2, gather_op_2}, 0);
auto reshape = std::make_shared<opset7::Reshape>(data_2, concat, true);
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data_1, data_2});
f_ref = ngraph::clone_function(*f);
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({-1, 4, 64, -1}));
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}