diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index 098e79e88ae..2dd2a3f6b26 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -111,12 +111,12 @@ void ov::op::v8::If::validate_and_infer_types() { val.size() == 1, "The number of values in the If condition constant is greater than 1"); - auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX; - auto body = m_bodies[cond_index]; - auto input_descriptors = m_input_descriptions[cond_index]; - validate_and_infer_type_body(body, input_descriptors); + validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]); + validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]); auto output_nodes = outputs(); + auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX; + auto body = m_bodies[cond_index]; // shape and type inference for outputs from If operations for (const auto& output_descr : m_output_descriptions[cond_index]) { auto body_value = body->get_results().at(output_descr->m_body_value_index)->input_value(0); diff --git a/src/core/tests/CMakeLists.txt b/src/core/tests/CMakeLists.txt index cba7fcf6808..12d9778261d 100644 --- a/src/core/tests/CMakeLists.txt +++ b/src/core/tests/CMakeLists.txt @@ -160,6 +160,7 @@ set(SRC type_prop/hsigmoid.cpp type_prop/hswish.cpp type_prop/idft.cpp + type_prop/if.cpp type_prop/interpolate.cpp type_prop/logical_and.cpp type_prop/logical_not.cpp diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index 80cf9a7f3df..6cd8df2739b 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -25,12 +25,14 @@ TEST(type_prop, if_simple_test) { auto Ye = make_shared(element::f32, PartialShape::dynamic()); // Body auto then_op = std::make_shared(Xt, Yt); - auto then_op_res = std::make_shared(then_op); + auto convert_then_op = std::make_shared(then_op, element::f32); + auto then_op_res = std::make_shared(convert_then_op); auto then_body = make_shared(OutputVector{then_op_res}, ParameterVector{Xt, Yt}); auto else_op = std::make_shared(Xe, Ye); - auto else_op_res = std::make_shared(else_op); + auto convert_else_op = std::make_shared(else_op, element::f32); + auto else_op_res = std::make_shared(convert_else_op); auto else_body = make_shared(OutputVector{else_op_res}, ParameterVector{Xe, Ye}); auto if_op = make_shared(cond); if_op->set_then_body(then_body); @@ -42,6 +44,12 @@ TEST(type_prop, if_simple_test) { Shape out0_shape{32, 40, 10}; auto sh = result0->get_output_shape(0); EXPECT_EQ(sh, out0_shape); + // Check that If validation validates both bodies + convert_then_op->set_convert_element_type(ov::element::f16); + convert_else_op->set_convert_element_type(ov::element::f16); + if_op->validate_and_infer_types(); + EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16); + EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16); } TEST(type_prop, if_non_const_condition_test) {