[Transformations] Fix unroll if transformation rt info (#20047)

* [Transformations] Fix unroll if transformation rt info

* Update test to check body names
This commit is contained in:
Nadezhda Ageeva 2023-09-26 17:13:52 +04:00 committed by GitHub
parent 519f13a177
commit 9e7938106b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 1 deletions

View File

@ -42,6 +42,9 @@ bool ov::pass::UnrollIf::run_on_model(const std::shared_ptr<ov::Model>& f) {
auto body = (cond_value[0]) ? if_node->get_then_body() : if_node->get_else_body();
auto input_descriptions = if_node->get_input_descriptions(static_cast<int>(!cond_value[0]));
auto output_descriptions = if_node->get_output_descriptions(static_cast<int>(!cond_value[0]));
// copy rt info before reconnection
for (auto& op : body->get_ops())
copy_runtime_info({op, if_node}, op);
// connect inputs instead of body parameters
for (const auto& input_descr : input_descriptions) {
@ -67,7 +70,6 @@ bool ov::pass::UnrollIf::run_on_model(const std::shared_ptr<ov::Model>& f) {
}
is_applicable = true;
f->add_sinks(body->get_sinks());
copy_runtime_info(if_node, body->get_ops());
}
return is_applicable;
}

View File

@ -18,34 +18,51 @@
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/common_optimizations/push_constant_to_subgraph.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
using namespace ov;
using namespace testing;
std::shared_ptr<ov::Model> get_then_body() {
auto Xt = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
Xt->set_friendly_name("Xt");
auto Yt = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
Yt->set_friendly_name("Yt");
auto add_op = std::make_shared<ov::op::v1::Add>(Xt, Yt);
add_op->set_friendly_name("add_op");
auto then_op_result = std::make_shared<ov::op::v0::Result>(add_op);
then_op_result->set_friendly_name("then_op_result");
auto then_body = std::make_shared<ov::Model>(ov::OutputVector{then_op_result}, ov::ParameterVector{Xt, Yt});
ov::pass::InitNodeInfo().run_on_model(then_body);
return then_body;
}
std::shared_ptr<ov::Model> get_else_body() {
auto Xe = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
Xe->set_friendly_name("Xe");
auto Ye = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
Ye->set_friendly_name("Ye");
auto mul_op = std::make_shared<ov::op::v1::Multiply>(Xe, Ye);
mul_op->set_friendly_name("mul_op");
auto else_op_result = std::make_shared<ov::op::v0::Result>(mul_op);
else_op_result->set_friendly_name("else_op_result");
auto else_body = std::make_shared<ov::Model>(ov::OutputVector{else_op_result}, ov::ParameterVector{Xe, Ye});
ov::pass::InitNodeInfo().run_on_model(else_body);
return else_body;
}
std::shared_ptr<ov::Model> create_if_model(bool condition) {
auto X = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
X->set_friendly_name("X");
auto Y = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
Y->set_friendly_name("y");
auto cond = std::make_shared<ov::op::v0::Constant>(ov::element::boolean, ov::Shape{1}, condition);
cond->set_friendly_name("cond");
auto if_op = std::make_shared<ov::op::v8::If>(cond);
if_op->set_friendly_name("if_op");
const auto& then_body = get_then_body();
const auto& else_body = get_else_body();
@ -57,6 +74,7 @@ std::shared_ptr<ov::Model> create_if_model(bool condition) {
if_op->set_input(Y, then_p[1], else_p[1]);
if_op->set_output(then_body->get_results()[0], else_body->get_results()[0]);
auto if_result = std::make_shared<ov::op::v0::Result>(if_op);
if_result->set_friendly_name("if_result");
return std::make_shared<ov::Model>(ov::NodeVector{if_result}, ov::ParameterVector{X, Y});
}
@ -78,6 +96,18 @@ TEST(TransformationTests, UnrollIfCondIsTrue) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
for (auto& op : f->get_ops()) {
std::vector<std::string> fused_names = ov::getFusedNamesVector(op);
if (ov::is_type<ov::op::v1::Add>(op)) {
ASSERT_EQ(2, fused_names.size());
ASSERT_TRUE(ov::util::contains(fused_names, "add_op"));
ASSERT_TRUE(ov::util::contains(fused_names, "if_op"));
} else {
ASSERT_EQ(1, fused_names.size());
ASSERT_TRUE(!ov::util::contains(fused_names, "if_op"));
}
}
}
TEST(TransformationTests, UnrollIfCondIsFalse) {
@ -97,6 +127,18 @@ TEST(TransformationTests, UnrollIfCondIsFalse) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
for (auto& op : f->get_ops()) {
std::vector<std::string> fused_names = ov::getFusedNamesVector(op);
if (ov::is_type<ov::op::v1::Multiply>(op)) {
ASSERT_EQ(2, fused_names.size());
ASSERT_TRUE(ov::util::contains(fused_names, "mul_op"));
ASSERT_TRUE(ov::util::contains(fused_names, "if_op"));
} else {
ASSERT_EQ(1, fused_names.size());
ASSERT_TRUE(!ov::util::contains(fused_names, "if_op"));
}
}
}
TEST(TransformationTests, UnrollIfWithSplitInput) {