[Transformations] Fix Parameter name override while removing Select node (#16934)

Details:
Applies valid node replacement method which avoids Parameter name override

Tickets: 101209

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
Tomasz Jankowski 2023-04-14 18:36:25 +02:00 committed by GitHub
parent 25058da48f
commit 129670ab1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 4 deletions

View File

@ -31,7 +31,7 @@ ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
NodeRegistry copy_from; NodeRegistry copy_from;
NodeRegistry copy_to; NodeRegistry copy_to;
auto& pattern_map = m.get_pattern_value_map(); auto& pattern_map = m.get_pattern_value_map();
auto select_value = pattern_map.at(select_pattern); auto& select_value = pattern_map.at(select_pattern);
auto select = std::dynamic_pointer_cast<Select>(select_value.get_node_shared_ptr()); auto select = std::dynamic_pointer_cast<Select>(select_value.get_node_shared_ptr());
if (!select) { if (!select) {
return false; return false;
@ -70,9 +70,7 @@ ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
if (select_shape.is_static() && branch_output_shape.same_scheme(select_shape)) { if (select_shape.is_static() && branch_output_shape.same_scheme(select_shape)) {
// Broadcast is not needed if the select shape is exactly the same as the selected branch // Broadcast is not needed if the select shape is exactly the same as the selected branch
select->output(0).replace(branch_output); return replace_output_update_name(select->output(0), branch_output);
replace_output_update_name(select->output(0), branch_output);
branch_output.get_node_shared_ptr()->set_friendly_name(select->get_friendly_name());
} else if (select_shape.is_static()) { } else if (select_shape.is_static()) {
// if the shape of the selected branch is not the same, it needs the broadcasting // if the shape of the selected branch is not the same, it needs the broadcasting
NodeRegistry copy_to; NodeRegistry copy_to;

View File

@ -128,3 +128,21 @@ static const std::vector<SelectWithOneValueConditionParams> params = {
}; };
INSTANTIATE_TEST_SUITE_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTest, ValuesIn(params)); INSTANTIATE_TEST_SUITE_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTest, ValuesIn(params));
TEST(TransformationTests, SelectWithOneValueCondition_DontRenameParameter) {
auto x = make_shared<Parameter>(f32, Shape{3});
x->set_friendly_name("X");
auto y = make_shared<Parameter>(f32, Shape{3});
auto relu = make_shared<Relu>(y);
auto condition = Constant::create(boolean, Shape{3}, {true, true, true});
auto select = make_shared<Select>(condition, x, relu);
auto abs = make_shared<Abs>(select);
auto model = make_shared<Model>(OutputVector{abs}, ParameterVector{x, y});
pass::Manager manager;
manager.register_pass<pass::SelectWithOneValueCondition>();
manager.run_passes(model);
ASSERT_EQ(count_ops_of_type<Select>(model), 0);
EXPECT_EQ(model->get_parameters()[0]->get_friendly_name(), "X");
}