If shape inference - scalar and 1D union handle (#11499)
This commit is contained in:
@@ -34,8 +34,16 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const
|
||||
|
||||
// if rangs of shapes are not equal or rang of one of them is dynamic function
|
||||
// return shape with dynamic rank
|
||||
if (then_rank.is_dynamic() || else_rank.is_dynamic() || then_rank.get_length() != else_rank.get_length()) {
|
||||
return ov::PartialShape::dynamic(ngraph::Rank::dynamic());
|
||||
if (then_rank.is_dynamic() || else_rank.is_dynamic()) {
|
||||
return ov::PartialShape::dynamic();
|
||||
}
|
||||
if (then_rank.get_length() != else_rank.get_length()) {
|
||||
// Union of scalar and 1D case
|
||||
if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) {
|
||||
return ov::PartialShape::dynamic(1);
|
||||
} else {
|
||||
return ov::PartialShape::dynamic();
|
||||
}
|
||||
}
|
||||
std::vector<ov::Dimension> new_dims;
|
||||
|
||||
|
||||
@@ -274,4 +274,64 @@ TEST(type_prop, if_dynamic_inputs) {
|
||||
auto res_it = dynamic_shape.begin();
|
||||
EXPECT_EQ(*exp_res_it, *res_it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, if_scalar_and_1d_union) {
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f32, Shape{});
|
||||
auto Y = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(1));
|
||||
auto cond = make_shared<op::Parameter>(element::boolean, Shape{});
|
||||
|
||||
// Body parameters
|
||||
auto Xt = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Ye = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
// Body
|
||||
auto then_op = std::make_shared<op::v1::Add>(Xt, Xt);
|
||||
auto then_body_res = make_shared<op::Result>(then_op);
|
||||
auto then_body = make_shared<ngraph::Function>(OutputVector{then_body_res}, ParameterVector{Xt});
|
||||
|
||||
auto else_op = std::make_shared<op::v1::Maximum>(Ye, Ye);
|
||||
auto else_body_res = make_shared<op::Result>(else_op);
|
||||
auto else_body = make_shared<ngraph::Function>(OutputVector{else_body_res}, ParameterVector{Ye});
|
||||
|
||||
auto if_op = make_shared<op::v8::If>(cond);
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(X, Xt, nullptr);
|
||||
if_op->set_input(Y, nullptr, Ye);
|
||||
auto res = if_op->set_output(then_body_res, else_body_res);
|
||||
auto result0 = make_shared<op::Result>(res);
|
||||
PartialShape out_shape{PartialShape::dynamic(1)};
|
||||
auto sh = result0->get_output_partial_shape(0);
|
||||
EXPECT_EQ(sh, out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, if_scalar_and_1d_static_union) {
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f32, Shape{});
|
||||
auto Y = make_shared<op::Parameter>(element::f32, PartialShape{8});
|
||||
auto cond = make_shared<op::Parameter>(element::boolean, Shape{});
|
||||
|
||||
// Body parameters
|
||||
auto Xt = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Ye = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
// Body
|
||||
auto then_op = std::make_shared<op::v1::Add>(Xt, Xt);
|
||||
auto then_body_res = make_shared<op::Result>(then_op);
|
||||
auto then_body = make_shared<ngraph::Function>(OutputVector{then_body_res}, ParameterVector{Xt});
|
||||
|
||||
auto else_op = std::make_shared<op::v1::Maximum>(Ye, Ye);
|
||||
auto else_body_res = make_shared<op::Result>(else_op);
|
||||
auto else_body = make_shared<ngraph::Function>(OutputVector{else_body_res}, ParameterVector{Ye});
|
||||
|
||||
auto if_op = make_shared<op::v8::If>(cond);
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(X, Xt, nullptr);
|
||||
if_op->set_input(Y, nullptr, Ye);
|
||||
auto res = if_op->set_output(then_body_res, else_body_res);
|
||||
auto result0 = make_shared<op::Result>(res);
|
||||
PartialShape out_shape{PartialShape::dynamic(1)};
|
||||
auto sh = result0->get_output_partial_shape(0);
|
||||
EXPECT_EQ(sh, out_shape);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user