[Transformations] Fix Parameter name override while removing Select node (#16934)
Details: Applies valid node replacement method which avoids Parameter name override Tickets: 101209 Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
parent
25058da48f
commit
129670ab1e
@ -31,7 +31,7 @@ ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
|
||||
NodeRegistry copy_from;
|
||||
NodeRegistry copy_to;
|
||||
auto& pattern_map = m.get_pattern_value_map();
|
||||
auto select_value = pattern_map.at(select_pattern);
|
||||
auto& select_value = pattern_map.at(select_pattern);
|
||||
auto select = std::dynamic_pointer_cast<Select>(select_value.get_node_shared_ptr());
|
||||
if (!select) {
|
||||
return false;
|
||||
@ -70,9 +70,7 @@ ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
|
||||
|
||||
if (select_shape.is_static() && branch_output_shape.same_scheme(select_shape)) {
|
||||
// Broadcast is not needed if the select shape is exactly the same as the selected branch
|
||||
select->output(0).replace(branch_output);
|
||||
replace_output_update_name(select->output(0), branch_output);
|
||||
branch_output.get_node_shared_ptr()->set_friendly_name(select->get_friendly_name());
|
||||
return replace_output_update_name(select->output(0), branch_output);
|
||||
} else if (select_shape.is_static()) {
|
||||
// if the shape of the selected branch is not the same, it needs the broadcasting
|
||||
NodeRegistry copy_to;
|
||||
|
@ -128,3 +128,21 @@ static const std::vector<SelectWithOneValueConditionParams> params = {
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTest, ValuesIn(params));
|
||||
|
||||
TEST(TransformationTests, SelectWithOneValueCondition_DontRenameParameter) {
|
||||
auto x = make_shared<Parameter>(f32, Shape{3});
|
||||
x->set_friendly_name("X");
|
||||
auto y = make_shared<Parameter>(f32, Shape{3});
|
||||
auto relu = make_shared<Relu>(y);
|
||||
auto condition = Constant::create(boolean, Shape{3}, {true, true, true});
|
||||
auto select = make_shared<Select>(condition, x, relu);
|
||||
auto abs = make_shared<Abs>(select);
|
||||
auto model = make_shared<Model>(OutputVector{abs}, ParameterVector{x, y});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::SelectWithOneValueCondition>();
|
||||
manager.run_passes(model);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<Select>(model), 0);
|
||||
EXPECT_EQ(model->get_parameters()[0]->get_friendly_name(), "X");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user