ConcatFusion: check that replacing input has only 1 consumer (#21425)
* ConcatFusion: check that replacing input has only 1 consumer * Add test
This commit is contained in:
parent
ba735c9149
commit
7d2afa4d38
@ -23,7 +23,7 @@ ov::pass::ConcatFusion::ConcatFusion() {
|
||||
auto is_aplicable = false;
|
||||
for (auto input : concat->input_values()) {
|
||||
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
|
||||
if (inp_concat && inp_concat->get_axis() == axis) {
|
||||
if (inp_concat && inp_concat->get_axis() == axis && inp_concat->output(0).get_target_inputs().size() == 1) {
|
||||
is_aplicable = true;
|
||||
}
|
||||
}
|
||||
@ -40,7 +40,7 @@ ov::pass::ConcatFusion::ConcatFusion() {
|
||||
OutputVector new_inputs;
|
||||
for (auto input : concat->input_values()) {
|
||||
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
|
||||
if (inp_concat && inp_concat->get_axis() == axis) {
|
||||
if (inp_concat && inp_concat->get_axis() == axis && inp_concat->output(0).get_target_inputs().size() == 1) {
|
||||
const auto inp_concat_inps = inp_concat->input_values();
|
||||
new_inputs.insert(new_inputs.end(), inp_concat_inps.begin(), inp_concat_inps.end());
|
||||
} else {
|
||||
|
@ -32,3 +32,24 @@ TEST_F(TransformationTestsF, ConcatFusedToConcat) {
|
||||
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConcatWithSeveralConsumersNotFused) {
|
||||
{
|
||||
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
|
||||
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
|
||||
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{concat1, data}, 1);
|
||||
auto mul = std::make_shared<opset13::Multiply>(concat1, concat1);
|
||||
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{mul, concat2}, 1);
|
||||
auto result = std::make_shared<opset13::Result>(concat3);
|
||||
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
|
||||
manager.register_pass<pass::ConcatFusion>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
|
||||
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
|
||||
auto mul = std::make_shared<opset13::Multiply>(concat1, concat1);
|
||||
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{mul, concat1, data}, 1);
|
||||
auto result = std::make_shared<opset13::Result>(concat3);
|
||||
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user