From 7d2afa4d38ed6b0a3e3eed6f15fd8dde4e610af5 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Fri, 1 Dec 2023 17:58:19 +0100 Subject: [PATCH] ConcatFusion: check that replacing input has only 1 consumer (#21425) * ConcatFusion: check that replacing input has only 1 consumer * Add test --- .../common_optimizations/concat_fusion.cpp | 4 ++-- .../common_optimizations/concat_fusion.cpp | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp index 2975a6c9fe5..49994d809aa 100644 --- a/src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp @@ -23,7 +23,7 @@ ov::pass::ConcatFusion::ConcatFusion() { auto is_aplicable = false; for (auto input : concat->input_values()) { const auto inp_concat = std::dynamic_pointer_cast(input.get_node_shared_ptr()); - if (inp_concat && inp_concat->get_axis() == axis) { + if (inp_concat && inp_concat->get_axis() == axis && inp_concat->output(0).get_target_inputs().size() == 1) { is_aplicable = true; } } @@ -40,7 +40,7 @@ ov::pass::ConcatFusion::ConcatFusion() { OutputVector new_inputs; for (auto input : concat->input_values()) { const auto inp_concat = std::dynamic_pointer_cast(input.get_node_shared_ptr()); - if (inp_concat && inp_concat->get_axis() == axis) { + if (inp_concat && inp_concat->get_axis() == axis && inp_concat->output(0).get_target_inputs().size() == 1) { const auto inp_concat_inps = inp_concat->input_values(); new_inputs.insert(new_inputs.end(), inp_concat_inps.begin(), inp_concat_inps.end()); } else { diff --git a/src/common/transformations/tests/common_optimizations/concat_fusion.cpp b/src/common/transformations/tests/common_optimizations/concat_fusion.cpp index 1b3031622ee..9e2defae977 100644 --- a/src/common/transformations/tests/common_optimizations/concat_fusion.cpp +++ b/src/common/transformations/tests/common_optimizations/concat_fusion.cpp @@ -32,3 +32,24 @@ TEST_F(TransformationTestsF, ConcatFusedToConcat) { model_ref = std::make_shared(ResultVector{result}, ParameterVector{data, data2}); } } + +TEST_F(TransformationTestsF, ConcatWithSeveralConsumersNotFused) { + { + auto data = std::make_shared(element::f32, PartialShape{1, 3, 14, 14}); + auto concat1 = std::make_shared(OutputVector{data, data}, 1); + auto concat2 = std::make_shared(OutputVector{concat1, data}, 1); + auto mul = std::make_shared(concat1, concat1); + auto concat3 = std::make_shared(OutputVector{mul, concat2}, 1); + auto result = std::make_shared(concat3); + model = std::make_shared(ResultVector{result}, ParameterVector{data}); + manager.register_pass(); + } + { + auto data = std::make_shared(element::f32, PartialShape{1, 3, 14, 14}); + auto concat1 = std::make_shared(OutputVector{data, data}, 1); + auto mul = std::make_shared(concat1, concat1); + auto concat3 = std::make_shared(OutputVector{mul, concat1, data}, 1); + auto result = std::make_shared(concat3); + model_ref = std::make_shared(ResultVector{result}, ParameterVector{data}); + } +}