Fix If operation validation function (#9726)

This commit is contained in:
Gleb Kazantaev 2022-01-18 12:41:50 +03:00 committed by GitHub
parent 3a9dba7362
commit 18eaaedb39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 6 deletions

View File

@ -111,12 +111,12 @@ void ov::op::v8::If::validate_and_infer_types() {
val.size() == 1,
"The number of values in the If condition constant is greater than 1");
auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX;
auto body = m_bodies[cond_index];
auto input_descriptors = m_input_descriptions[cond_index];
validate_and_infer_type_body(body, input_descriptors);
validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]);
validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]);
auto output_nodes = outputs();
auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX;
auto body = m_bodies[cond_index];
// shape and type inference for outputs from If operations
for (const auto& output_descr : m_output_descriptions[cond_index]) {
auto body_value = body->get_results().at(output_descr->m_body_value_index)->input_value(0);

View File

@ -160,6 +160,7 @@ set(SRC
type_prop/hsigmoid.cpp
type_prop/hswish.cpp
type_prop/idft.cpp
type_prop/if.cpp
type_prop/interpolate.cpp
type_prop/logical_and.cpp
type_prop/logical_not.cpp

View File

@ -25,12 +25,14 @@ TEST(type_prop, if_simple_test) {
auto Ye = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
// Body
auto then_op = std::make_shared<op::v1::Add>(Xt, Yt);
auto then_op_res = std::make_shared<op::Result>(then_op);
auto convert_then_op = std::make_shared<op::v0::Convert>(then_op, element::f32);
auto then_op_res = std::make_shared<op::Result>(convert_then_op);
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
auto else_op = std::make_shared<op::v1::Maximum>(Xe, Ye);
auto else_op_res = std::make_shared<op::Result>(else_op);
auto convert_else_op = std::make_shared<op::v0::Convert>(else_op, element::f32);
auto else_op_res = std::make_shared<op::Result>(convert_else_op);
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
auto if_op = make_shared<op::v8::If>(cond);
if_op->set_then_body(then_body);
@ -42,6 +44,12 @@ TEST(type_prop, if_simple_test) {
Shape out0_shape{32, 40, 10};
auto sh = result0->get_output_shape(0);
EXPECT_EQ(sh, out0_shape);
// Check that If validation validates both bodies
convert_then_op->set_convert_element_type(ov::element::f16);
convert_else_op->set_convert_element_type(ov::element::f16);
if_op->validate_and_infer_types();
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
}
TEST(type_prop, if_non_const_condition_test) {