Validate only one If body if condiction is const (#18087)
* Validate only one If body if condiction is const * Add test with invalid body
This commit is contained in:
@@ -107,8 +107,12 @@ void ov::op::v8::If::validate_and_infer_types() {
|
||||
val.size() == 1,
|
||||
"The number of values in the If condition constant is greater than 1");
|
||||
|
||||
validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]);
|
||||
validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]);
|
||||
// Condition is constant we only need to validate one body
|
||||
if (val[0]) {
|
||||
validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]);
|
||||
} else {
|
||||
validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]);
|
||||
}
|
||||
auto output_nodes = outputs();
|
||||
|
||||
auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX;
|
||||
|
||||
@@ -44,12 +44,12 @@ TEST(type_prop, if_simple_test) {
|
||||
Shape out0_shape{32, 40, 10};
|
||||
auto sh = result0->get_output_shape(0);
|
||||
EXPECT_EQ(sh, out0_shape);
|
||||
// Check that If validation validates both bodies
|
||||
// Check that If validation when condition is constant validates single body
|
||||
convert_then_op->set_convert_element_type(ov::element::f16);
|
||||
convert_else_op->set_convert_element_type(ov::element::f16);
|
||||
if_op->validate_and_infer_types();
|
||||
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
|
||||
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
|
||||
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f32);
|
||||
}
|
||||
|
||||
TEST(type_prop, if_non_const_condition_test) {
|
||||
@@ -340,7 +340,7 @@ TEST(type_prop, if_element_type_dynamic) {
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, false);
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
@@ -367,8 +367,49 @@ TEST(type_prop, if_element_type_dynamic) {
|
||||
Shape out0_shape{32, 40, 10};
|
||||
auto sh = result0->get_output_shape(0);
|
||||
EXPECT_EQ(sh, out0_shape);
|
||||
// Check that If validation validates both bodies
|
||||
// Check that If validation when condition is constant validates single body
|
||||
if_op->validate_and_infer_types();
|
||||
EXPECT_EQ(then_op_res->get_element_type(), ov::element::dynamic);
|
||||
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
|
||||
}
|
||||
|
||||
TEST(type_prop, if_invalid_false_body) {
|
||||
// That which we iterate over
|
||||
auto X = make_shared<op::Parameter>(element::f16, Shape{32, 40, 10});
|
||||
auto Y = make_shared<op::Parameter>(element::f16, Shape{32, 40});
|
||||
auto cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{1}, false);
|
||||
|
||||
// Set up the cell body, a function from (Xi, Yi) -> (Zo)
|
||||
// Body parameters
|
||||
auto Xt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
auto Yt = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
auto Xe = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
auto Ye = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
// Body
|
||||
auto axes_4d = opset5::Constant::create(element::i32, ngraph::Shape{2}, {2, 3});
|
||||
auto then_reduce_op = std::make_shared<op::v1::ReduceMean>(Xt, Yt);
|
||||
auto then_op = std::make_shared<op::v1::Add>(then_reduce_op, Yt);
|
||||
auto then_op_res = std::make_shared<op::Result>(then_op);
|
||||
|
||||
auto then_body = make_shared<ngraph::Function>(OutputVector{then_op_res}, ParameterVector{Xt, Yt});
|
||||
|
||||
auto axes_3d = opset5::Constant::create(element::i32, ngraph::Shape{1}, {2});
|
||||
auto else_reduce_op = std::make_shared<op::v1::ReduceMean>(Xe, axes_3d);
|
||||
auto else_op = std::make_shared<op::v1::Add>(else_reduce_op, Ye);
|
||||
auto else_op_res = std::make_shared<op::Result>(else_op);
|
||||
auto else_body = make_shared<ngraph::Function>(OutputVector{else_op_res}, ParameterVector{Xe, Ye});
|
||||
auto if_op = make_shared<op::v8::If>(cond);
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(X, Xt, Xe);
|
||||
if_op->set_input(Y, Yt, Ye);
|
||||
auto res = if_op->set_output(then_op_res, else_op_res);
|
||||
auto result0 = make_shared<op::Result>(res);
|
||||
Shape out0_shape{32, 40};
|
||||
auto sh = result0->get_output_shape(0);
|
||||
EXPECT_EQ(sh, out0_shape);
|
||||
// Check that If validation when condition is constant validates single body
|
||||
if_op->validate_and_infer_types();
|
||||
EXPECT_EQ(then_op_res->get_element_type(), ov::element::dynamic);
|
||||
EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16);
|
||||
EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user