If shape inference - scalar and 1D union handle (#11499)

This commit is contained in:
Mateusz Bencer
2022-05-06 23:55:36 +02:00
committed by GitHub
parent e0e916b557
commit d60deae083
2 changed files with 71 additions and 3 deletions

View File

@@ -34,8 +34,16 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const
// if rangs of shapes are not equal or rang of one of them is dynamic function
// return shape with dynamic rank
if (then_rank.is_dynamic() || else_rank.is_dynamic() || then_rank.get_length() != else_rank.get_length()) {
return ov::PartialShape::dynamic(ngraph::Rank::dynamic());
if (then_rank.is_dynamic() || else_rank.is_dynamic()) {
return ov::PartialShape::dynamic();
}
if (then_rank.get_length() != else_rank.get_length()) {
// Union of scalar and 1D case
if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) {
return ov::PartialShape::dynamic(1);
} else {
return ov::PartialShape::dynamic();
}
}
std::vector<ov::Dimension> new_dims;

View File

@@ -274,4 +274,64 @@ TEST(type_prop, if_dynamic_inputs) {
auto res_it = dynamic_shape.begin();
EXPECT_EQ(*exp_res_it, *res_it);
}
}
}
TEST(type_prop, if_scalar_and_1d_union) {
// That which we iterate over
auto X = make_shared<op::Parameter>(element::f32, Shape{});
auto Y = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(1));
auto cond = make_shared<op::Parameter>(element::boolean, Shape{});
// Body parameters
auto Xt = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto Ye = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
// Body
auto then_op = std::make_shared<op::v1::Add>(Xt, Xt);
auto then_body_res = make_shared<op::Result>(then_op);
auto then_body = make_shared<ngraph::Function>(OutputVector{then_body_res}, ParameterVector{Xt});
auto else_op = std::make_shared<op::v1::Maximum>(Ye, Ye);
auto else_body_res = make_shared<op::Result>(else_op);
auto else_body = make_shared<ngraph::Function>(OutputVector{else_body_res}, ParameterVector{Ye});
auto if_op = make_shared<op::v8::If>(cond);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, nullptr);
if_op->set_input(Y, nullptr, Ye);
auto res = if_op->set_output(then_body_res, else_body_res);
auto result0 = make_shared<op::Result>(res);
PartialShape out_shape{PartialShape::dynamic(1)};
auto sh = result0->get_output_partial_shape(0);
EXPECT_EQ(sh, out_shape);
}
TEST(type_prop, if_scalar_and_1d_static_union) {
// That which we iterate over
auto X = make_shared<op::Parameter>(element::f32, Shape{});
auto Y = make_shared<op::Parameter>(element::f32, PartialShape{8});
auto cond = make_shared<op::Parameter>(element::boolean, Shape{});
// Body parameters
auto Xt = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto Ye = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
// Body
auto then_op = std::make_shared<op::v1::Add>(Xt, Xt);
auto then_body_res = make_shared<op::Result>(then_op);
auto then_body = make_shared<ngraph::Function>(OutputVector{then_body_res}, ParameterVector{Xt});
auto else_op = std::make_shared<op::v1::Maximum>(Ye, Ye);
auto else_body_res = make_shared<op::Result>(else_op);
auto else_body = make_shared<ngraph::Function>(OutputVector{else_body_res}, ParameterVector{Ye});
auto if_op = make_shared<op::v8::If>(cond);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, nullptr);
if_op->set_input(Y, nullptr, Ye);
auto res = if_op->set_output(then_body_res, else_body_res);
auto result0 = make_shared<op::Result>(res);
PartialShape out_shape{PartialShape::dynamic(1)};
auto sh = result0->get_output_partial_shape(0);
EXPECT_EQ(sh, out_shape);
}