[Transformations] Fix unroll if transformation rt info (#20047)
* [Transformations] Fix unroll if transformation rt info * Update test to check body names
This commit is contained in:
parent
519f13a177
commit
9e7938106b
@ -42,6 +42,9 @@ bool ov::pass::UnrollIf::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
auto body = (cond_value[0]) ? if_node->get_then_body() : if_node->get_else_body();
|
||||
auto input_descriptions = if_node->get_input_descriptions(static_cast<int>(!cond_value[0]));
|
||||
auto output_descriptions = if_node->get_output_descriptions(static_cast<int>(!cond_value[0]));
|
||||
// copy rt info before reconnection
|
||||
for (auto& op : body->get_ops())
|
||||
copy_runtime_info({op, if_node}, op);
|
||||
|
||||
// connect inputs instead of body parameters
|
||||
for (const auto& input_descr : input_descriptions) {
|
||||
@ -67,7 +70,6 @@ bool ov::pass::UnrollIf::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
}
|
||||
is_applicable = true;
|
||||
f->add_sinks(body->get_sinks());
|
||||
copy_runtime_info(if_node, body->get_ops());
|
||||
}
|
||||
return is_applicable;
|
||||
}
|
||||
|
@ -18,34 +18,51 @@
|
||||
#include "openvino/op/variadic_split.hpp"
|
||||
#include "openvino/pass/constant_folding.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/common_optimizations/push_constant_to_subgraph.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace testing;
|
||||
|
||||
std::shared_ptr<ov::Model> get_then_body() {
|
||||
auto Xt = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
|
||||
Xt->set_friendly_name("Xt");
|
||||
auto Yt = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
|
||||
Yt->set_friendly_name("Yt");
|
||||
auto add_op = std::make_shared<ov::op::v1::Add>(Xt, Yt);
|
||||
add_op->set_friendly_name("add_op");
|
||||
auto then_op_result = std::make_shared<ov::op::v0::Result>(add_op);
|
||||
then_op_result->set_friendly_name("then_op_result");
|
||||
auto then_body = std::make_shared<ov::Model>(ov::OutputVector{then_op_result}, ov::ParameterVector{Xt, Yt});
|
||||
ov::pass::InitNodeInfo().run_on_model(then_body);
|
||||
return then_body;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> get_else_body() {
|
||||
auto Xe = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
|
||||
Xe->set_friendly_name("Xe");
|
||||
auto Ye = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
|
||||
Ye->set_friendly_name("Ye");
|
||||
auto mul_op = std::make_shared<ov::op::v1::Multiply>(Xe, Ye);
|
||||
mul_op->set_friendly_name("mul_op");
|
||||
auto else_op_result = std::make_shared<ov::op::v0::Result>(mul_op);
|
||||
else_op_result->set_friendly_name("else_op_result");
|
||||
auto else_body = std::make_shared<ov::Model>(ov::OutputVector{else_op_result}, ov::ParameterVector{Xe, Ye});
|
||||
ov::pass::InitNodeInfo().run_on_model(else_body);
|
||||
return else_body;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> create_if_model(bool condition) {
|
||||
auto X = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
|
||||
X->set_friendly_name("X");
|
||||
auto Y = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{3});
|
||||
Y->set_friendly_name("y");
|
||||
auto cond = std::make_shared<ov::op::v0::Constant>(ov::element::boolean, ov::Shape{1}, condition);
|
||||
cond->set_friendly_name("cond");
|
||||
auto if_op = std::make_shared<ov::op::v8::If>(cond);
|
||||
if_op->set_friendly_name("if_op");
|
||||
const auto& then_body = get_then_body();
|
||||
const auto& else_body = get_else_body();
|
||||
|
||||
@ -57,6 +74,7 @@ std::shared_ptr<ov::Model> create_if_model(bool condition) {
|
||||
if_op->set_input(Y, then_p[1], else_p[1]);
|
||||
if_op->set_output(then_body->get_results()[0], else_body->get_results()[0]);
|
||||
auto if_result = std::make_shared<ov::op::v0::Result>(if_op);
|
||||
if_result->set_friendly_name("if_result");
|
||||
|
||||
return std::make_shared<ov::Model>(ov::NodeVector{if_result}, ov::ParameterVector{X, Y});
|
||||
}
|
||||
@ -78,6 +96,18 @@ TEST(TransformationTests, UnrollIfCondIsTrue) {
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
for (auto& op : f->get_ops()) {
|
||||
std::vector<std::string> fused_names = ov::getFusedNamesVector(op);
|
||||
if (ov::is_type<ov::op::v1::Add>(op)) {
|
||||
ASSERT_EQ(2, fused_names.size());
|
||||
ASSERT_TRUE(ov::util::contains(fused_names, "add_op"));
|
||||
ASSERT_TRUE(ov::util::contains(fused_names, "if_op"));
|
||||
} else {
|
||||
ASSERT_EQ(1, fused_names.size());
|
||||
ASSERT_TRUE(!ov::util::contains(fused_names, "if_op"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransformationTests, UnrollIfCondIsFalse) {
|
||||
@ -97,6 +127,18 @@ TEST(TransformationTests, UnrollIfCondIsFalse) {
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
for (auto& op : f->get_ops()) {
|
||||
std::vector<std::string> fused_names = ov::getFusedNamesVector(op);
|
||||
if (ov::is_type<ov::op::v1::Multiply>(op)) {
|
||||
ASSERT_EQ(2, fused_names.size());
|
||||
ASSERT_TRUE(ov::util::contains(fused_names, "mul_op"));
|
||||
ASSERT_TRUE(ov::util::contains(fused_names, "if_op"));
|
||||
} else {
|
||||
ASSERT_EQ(1, fused_names.size());
|
||||
ASSERT_TRUE(!ov::util::contains(fused_names, "if_op"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransformationTests, UnrollIfWithSplitInput) {
|
||||
|
Loading…
Reference in New Issue
Block a user