diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index dc15d9775cb..d5518008fb6 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -34,8 +34,16 @@ static ov::PartialShape resolve_shape(const ov::PartialShape& then_pshape, const // if rangs of shapes are not equal or rang of one of them is dynamic function // return shape with dynamic rank - if (then_rank.is_dynamic() || else_rank.is_dynamic() || then_rank.get_length() != else_rank.get_length()) { - return ov::PartialShape::dynamic(ngraph::Rank::dynamic()); + if (then_rank.is_dynamic() || else_rank.is_dynamic()) { + return ov::PartialShape::dynamic(); + } + if (then_rank.get_length() != else_rank.get_length()) { + // Union of scalar and 1D case + if (then_rank.get_length() <= 1 && else_rank.get_length() <= 1) { + return ov::PartialShape::dynamic(1); + } else { + return ov::PartialShape::dynamic(); + } } std::vector new_dims; diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index af68c2246aa..3845cdf2c1e 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -274,4 +274,64 @@ TEST(type_prop, if_dynamic_inputs) { auto res_it = dynamic_shape.begin(); EXPECT_EQ(*exp_res_it, *res_it); } -} \ No newline at end of file +} + +TEST(type_prop, if_scalar_and_1d_union) { + // That which we iterate over + auto X = make_shared(element::f32, Shape{}); + auto Y = make_shared(element::f32, PartialShape::dynamic(1)); + auto cond = make_shared(element::boolean, Shape{}); + + // Body parameters + auto Xt = make_shared(element::f32, PartialShape::dynamic()); + auto Ye = make_shared(element::f32, PartialShape::dynamic()); + // Body + auto then_op = std::make_shared(Xt, Xt); + auto then_body_res = make_shared(then_op); + auto then_body = make_shared(OutputVector{then_body_res}, ParameterVector{Xt}); + + auto else_op = std::make_shared(Ye, Ye); + auto else_body_res = make_shared(else_op); + auto else_body = make_shared(OutputVector{else_body_res}, ParameterVector{Ye}); + + auto if_op = make_shared(cond); + if_op->set_then_body(then_body); + if_op->set_else_body(else_body); + if_op->set_input(X, Xt, nullptr); + if_op->set_input(Y, nullptr, Ye); + auto res = if_op->set_output(then_body_res, else_body_res); + auto result0 = make_shared(res); + PartialShape out_shape{PartialShape::dynamic(1)}; + auto sh = result0->get_output_partial_shape(0); + EXPECT_EQ(sh, out_shape); +} + +TEST(type_prop, if_scalar_and_1d_static_union) { + // That which we iterate over + auto X = make_shared(element::f32, Shape{}); + auto Y = make_shared(element::f32, PartialShape{8}); + auto cond = make_shared(element::boolean, Shape{}); + + // Body parameters + auto Xt = make_shared(element::f32, PartialShape::dynamic()); + auto Ye = make_shared(element::f32, PartialShape::dynamic()); + // Body + auto then_op = std::make_shared(Xt, Xt); + auto then_body_res = make_shared(then_op); + auto then_body = make_shared(OutputVector{then_body_res}, ParameterVector{Xt}); + + auto else_op = std::make_shared(Ye, Ye); + auto else_body_res = make_shared(else_op); + auto else_body = make_shared(OutputVector{else_body_res}, ParameterVector{Ye}); + + auto if_op = make_shared(cond); + if_op->set_then_body(then_body); + if_op->set_else_body(else_body); + if_op->set_input(X, Xt, nullptr); + if_op->set_input(Y, nullptr, Ye); + auto res = if_op->set_output(then_body_res, else_body_res); + auto result0 = make_shared(res); + PartialShape out_shape{PartialShape::dynamic(1)}; + auto sh = result0->get_output_partial_shape(0); + EXPECT_EQ(sh, out_shape); +}