From 55156f9a6ceb59cd1066483832e1ba7fb55cf492 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 15 Jun 2023 22:04:10 +0200 Subject: [PATCH] Validate only one If body if condiction is const (#18087) * Validate only one If body if condiction is const * Add test with invalid body --- src/core/src/op/if.cpp | 8 ++++-- src/core/tests/type_prop/if.cpp | 51 +++++++++++++++++++++++++++++---- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/core/src/op/if.cpp b/src/core/src/op/if.cpp index cd008e419e3..eaffd049424 100644 --- a/src/core/src/op/if.cpp +++ b/src/core/src/op/if.cpp @@ -107,8 +107,12 @@ void ov::op::v8::If::validate_and_infer_types() { val.size() == 1, "The number of values in the If condition constant is greater than 1"); - validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]); - validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]); + // Condition is constant we only need to validate one body + if (val[0]) { + validate_and_infer_type_body(get_then_body(), m_input_descriptions[THEN_BODY_INDEX]); + } else { + validate_and_infer_type_body(get_else_body(), m_input_descriptions[ELSE_BODY_INDEX]); + } auto output_nodes = outputs(); auto cond_index = val[0] ? THEN_BODY_INDEX : ELSE_BODY_INDEX; diff --git a/src/core/tests/type_prop/if.cpp b/src/core/tests/type_prop/if.cpp index bfa211eb782..1c27923c4e3 100644 --- a/src/core/tests/type_prop/if.cpp +++ b/src/core/tests/type_prop/if.cpp @@ -44,12 +44,12 @@ TEST(type_prop, if_simple_test) { Shape out0_shape{32, 40, 10}; auto sh = result0->get_output_shape(0); EXPECT_EQ(sh, out0_shape); - // Check that If validation validates both bodies + // Check that If validation when condition is constant validates single body convert_then_op->set_convert_element_type(ov::element::f16); convert_else_op->set_convert_element_type(ov::element::f16); if_op->validate_and_infer_types(); - EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16); EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16); + EXPECT_EQ(else_op_res->get_element_type(), ov::element::f32); } TEST(type_prop, if_non_const_condition_test) { @@ -340,7 +340,7 @@ TEST(type_prop, if_element_type_dynamic) { // That which we iterate over auto X = make_shared(element::f16, Shape{32, 40, 10}); auto Y = make_shared(element::f16, Shape{32, 40, 10}); - auto cond = std::make_shared(ngraph::element::boolean, ngraph::Shape{1}, true); + auto cond = std::make_shared(ngraph::element::boolean, ngraph::Shape{1}, false); // Set up the cell body, a function from (Xi, Yi) -> (Zo) // Body parameters @@ -367,8 +367,49 @@ TEST(type_prop, if_element_type_dynamic) { Shape out0_shape{32, 40, 10}; auto sh = result0->get_output_shape(0); EXPECT_EQ(sh, out0_shape); - // Check that If validation validates both bodies + // Check that If validation when condition is constant validates single body if_op->validate_and_infer_types(); + EXPECT_EQ(then_op_res->get_element_type(), ov::element::dynamic); + EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16); +} + +TEST(type_prop, if_invalid_false_body) { + // That which we iterate over + auto X = make_shared(element::f16, Shape{32, 40, 10}); + auto Y = make_shared(element::f16, Shape{32, 40}); + auto cond = std::make_shared(ngraph::element::boolean, ngraph::Shape{1}, false); + + // Set up the cell body, a function from (Xi, Yi) -> (Zo) + // Body parameters + auto Xt = make_shared(element::dynamic, PartialShape::dynamic()); + auto Yt = make_shared(element::dynamic, PartialShape::dynamic()); + auto Xe = make_shared(element::dynamic, PartialShape::dynamic()); + auto Ye = make_shared(element::dynamic, PartialShape::dynamic()); + // Body + auto axes_4d = opset5::Constant::create(element::i32, ngraph::Shape{2}, {2, 3}); + auto then_reduce_op = std::make_shared(Xt, Yt); + auto then_op = std::make_shared(then_reduce_op, Yt); + auto then_op_res = std::make_shared(then_op); + + auto then_body = make_shared(OutputVector{then_op_res}, ParameterVector{Xt, Yt}); + + auto axes_3d = opset5::Constant::create(element::i32, ngraph::Shape{1}, {2}); + auto else_reduce_op = std::make_shared(Xe, axes_3d); + auto else_op = std::make_shared(else_reduce_op, Ye); + auto else_op_res = std::make_shared(else_op); + auto else_body = make_shared(OutputVector{else_op_res}, ParameterVector{Xe, Ye}); + auto if_op = make_shared(cond); + if_op->set_then_body(then_body); + if_op->set_else_body(else_body); + if_op->set_input(X, Xt, Xe); + if_op->set_input(Y, Yt, Ye); + auto res = if_op->set_output(then_op_res, else_op_res); + auto result0 = make_shared(res); + Shape out0_shape{32, 40}; + auto sh = result0->get_output_shape(0); + EXPECT_EQ(sh, out0_shape); + // Check that If validation when condition is constant validates single body + if_op->validate_and_infer_types(); + EXPECT_EQ(then_op_res->get_element_type(), ov::element::dynamic); EXPECT_EQ(else_op_res->get_element_type(), ov::element::f16); - EXPECT_EQ(then_op_res->get_element_type(), ov::element::f16); }