ConcatFusion: check that replacing input has only 1 consumer (#21425)

* ConcatFusion: check that replacing input has only 1 consumer

* Add test
This commit is contained in:
Maxim Vafin 2023-12-01 17:58:19 +01:00 committed by GitHub
parent ba735c9149
commit 7d2afa4d38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 2 deletions

View File

@ -23,7 +23,7 @@ ov::pass::ConcatFusion::ConcatFusion() {
auto is_aplicable = false;
for (auto input : concat->input_values()) {
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
if (inp_concat && inp_concat->get_axis() == axis) {
if (inp_concat && inp_concat->get_axis() == axis && inp_concat->output(0).get_target_inputs().size() == 1) {
is_aplicable = true;
}
}
@ -40,7 +40,7 @@ ov::pass::ConcatFusion::ConcatFusion() {
OutputVector new_inputs;
for (auto input : concat->input_values()) {
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
if (inp_concat && inp_concat->get_axis() == axis) {
if (inp_concat && inp_concat->get_axis() == axis && inp_concat->output(0).get_target_inputs().size() == 1) {
const auto inp_concat_inps = inp_concat->input_values();
new_inputs.insert(new_inputs.end(), inp_concat_inps.begin(), inp_concat_inps.end());
} else {

View File

@ -32,3 +32,24 @@ TEST_F(TransformationTestsF, ConcatFusedToConcat) {
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
}
}
TEST_F(TransformationTestsF, ConcatWithSeveralConsumersNotFused) {
{
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{concat1, data}, 1);
auto mul = std::make_shared<opset13::Multiply>(concat1, concat1);
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{mul, concat2}, 1);
auto result = std::make_shared<opset13::Result>(concat3);
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
manager.register_pass<pass::ConcatFusion>();
}
{
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
auto mul = std::make_shared<opset13::Multiply>(concat1, concat1);
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{mul, concat1, data}, 1);
auto result = std::make_shared<opset13::Result>(concat3);
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
}
}