Fix If operation validation function (#9726)
This commit is contained in:
parent
3a9dba7362
commit
18eaaedb39
@ -111,12 +111,12 @@ void ov::op::v8::If::validate_and_infer_types() {
|
||||
val.size() == 1,
|
||||
"The number of values in the If condition constant is greater than 1");
|
||||
|
||||
auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX;
|
||||
auto body = m_bodies[cond_index];
|
||||
auto input_descriptors = m_input_descriptions[cond_index];
|
||||
validate_and_infer_type_body(body, input_descriptors);
|
||||
validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]);
|
||||
validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]);
|
||||
auto output_nodes = outputs();
|
||||
|
||||
auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX;
|
||||
auto body = m_bodies[cond_index];
|
||||
// shape and type inference for outputs from If operations
|
||||
for (const auto& output_descr : m_output_descriptions[cond_index]) {
|
||||
auto body_value = body->get_results().at(output_descr->m_body_value_index)->input_value(0);
|
||||
|
@ -160,6 +160,7 @@ set(SRC
|
||||
type_prop/hsigmoid.cpp
|
||||
type_prop/hswish.cpp
|
||||
type_prop/idft.cpp
|
||||
type_prop/if.cpp
|
||||
type_prop/interpolate.cpp
|
||||
type_prop/logical_and.cpp
|
||||
type_prop/logical_not.cpp
|
||||
|
@ -25,12 +25,14 @@ TEST(type_prop, if_simple_test) {
|
||||
auto Ye = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
// Body
|
||||
auto then_op = std::make_shared<op::v1::Add>(Xt, Yt);
|
||||
auto then_op_res = std::make_shared<op::Result>(then_op);
|
||||
auto convert_then_op = std::make_shared<op::v0::Convert>(then_op, element::f32);
|
||||
auto then_op_res = std::make_shared<op::Result>(convert_then_op);
|
||||
|
||||
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
|
||||
|
||||
auto else_op = std::make_shared<op::v1::Maximum>(Xe, Ye);
|
||||
auto else_op_res = std::make_shared<op::Result>(else_op);
|
||||
auto convert_else_op = std::make_shared<op::v0::Convert>(else_op, element::f32);
|
||||
auto else_op_res = std::make_shared<op::Result>(convert_else_op);
|
||||
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
|
||||
auto if_op = make_shared<op::v8::If>(cond);
|
||||
if_op->set_then_body(then_body);
|
||||
@ -42,6 +44,12 @@ TEST(type_prop, if_simple_test) {
|
||||
Shape out0_shape{32, 40, 10};
|
||||
auto sh = result0->get_output_shape(0);
|
||||
EXPECT_EQ(sh, out0_shape);
|
||||
// Check that If validation validates both bodies
|
||||
convert_then_op->set_convert_element_type(ov::element::f16);
|
||||
convert_else_op->set_convert_element_type(ov::element::f16);
|
||||
if_op->validate_and_infer_types();
|
||||
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
|
||||
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
|
||||
}
|
||||
|
||||
TEST(type_prop, if_non_const_condition_test) {
|
||||
|
Loading…
Reference in New Issue
Block a user