SimplifySecondInputOfReshapeFix (#9809)
This commit is contained in:
parent
29d73ce3c8
commit
fd754ab917
@ -267,6 +267,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
if (const auto gather = as_type_ptr<op::util::GatherBase>(input.get_node_shared_ptr())) {
|
||||
auto indices_constant = as_type_ptr<opset8::Constant>(gather->get_input_node_shared_ptr(1));
|
||||
if (!indices_constant || !check_shape_of_gather(gather)) {
|
||||
gather_dims_expected_location++;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -581,3 +581,37 @@ TEST(TransformationTests, SimplifySecondInputOfReshapeTest16) {
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest17) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{-1, 256, -1};
|
||||
{
|
||||
auto data_1 = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto data_2 = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data_1);
|
||||
auto gather_op_1 = gather(shape_of_1, std::vector<int64_t>{0});
|
||||
|
||||
auto constant_1 = opset7::Constant::create(element::i64, Shape{1}, {4});
|
||||
auto constant_2 = opset7::Constant::create(element::i64, Shape{1}, {64});
|
||||
|
||||
auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data_2);
|
||||
auto gather_op_2 = gather(shape_of_2, std::vector<int64_t>{2});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op_1, constant_1, constant_2, gather_op_2}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data_2, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data_1, data_2});
|
||||
f_ref = ngraph::clone_function(*f);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({-1, 4, 64, -1}));
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user