Validate only one If body if condiction is const (#18087)

* Validate only one If body if condiction is const

* Add test with invalid body
This commit is contained in:
Maxim Vafin
2023-06-15 22:04:10 +02:00
committed by GitHub
parent 10ace822ef
commit 55156f9a6c
2 changed files with 52 additions and 7 deletions

View File

@@ -107,8 +107,12 @@ void ov::op::v8::If::validate_and_infer_types() {
val.size() == 1,
"The number of values in the If condition constant is greater than 1");
validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]);
validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]);
// Condition is constant we only need to validate one body
if (val[0]) {
validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]);
} else {
validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]);
}
auto output_nodes = outputs();
auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX;

View File

@@ -44,12 +44,12 @@ TEST(type_prop, if_simple_test) {
Shape out0_shape{32, 40, 10};
auto sh = result0->get_output_shape(0);
EXPECT_EQ(sh, out0_shape);
// Check that If validation validates both bodies
// Check that If validation when condition is constant validates single body
convert_then_op->set_convert_element_type(ov::element::f16);
convert_else_op->set_convert_element_type(ov::element::f16);
if_op->validate_and_infer_types();
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f32);
}
TEST(type_prop, if_non_const_condition_test) {
@@ -340,7 +340,7 @@ TEST(type_prop, if_element_type_dynamic) {
// That which we iterate over
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, true);
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, false);
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
// Body parameters
@@ -367,8 +367,49 @@ TEST(type_prop, if_element_type_dynamic) {
Shape out0_shape{32, 40, 10};
auto sh = result0->get_output_shape(0);
EXPECT_EQ(sh, out0_shape);
// Check that If validation validates both bodies
// Check that If validation when condition is constant validates single body
if_op->validate_and_infer_types();
EXPECT_EQ(then_op_res->get_element_type(), ov::element::dynamic);
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
}
TEST(type_prop, if_invalid_false_body) {
// That which we iterate over
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40});
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, false);
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
// Body parameters
auto Xt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto Yt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto Xe = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
auto Ye = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
// Body
auto axes_4d = opset5::Constant::create(element::i32, ngraph::Shape{2}, {2, 3});
auto then_reduce_op = std::make_shared<op::v1::ReduceMean>(Xt, Yt);
auto then_op = std::make_shared<op::v1::Add>(then_reduce_op, Yt);
auto then_op_res = std::make_shared<op::Result>(then_op);
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
auto axes_3d = opset5::Constant::create(element::i32, ngraph::Shape{1}, {2});
auto else_reduce_op = std::make_shared<op::v1::ReduceMean>(Xe, axes_3d);
auto else_op = std::make_shared<op::v1::Add>(else_reduce_op, Ye);
auto else_op_res = std::make_shared<op::Result>(else_op);
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
auto if_op = make_shared<op::v8::If>(cond);
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, Xe);
if_op->set_input(Y, Yt, Ye);
auto res = if_op->set_output(then_op_res, else_op_res);
auto result0 = make_shared<op::Result>(res);
Shape out0_shape{32, 40};
auto sh = result0->get_output_shape(0);
EXPECT_EQ(sh, out0_shape);
// Check that If validation when condition is constant validates single body
if_op->validate_and_infer_types();
EXPECT_EQ(then_op_res->get_element_type(), ov::element::dynamic);
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
}