Add transformation to fuse Concat sequence (#21310)

* Add transformation to fuse Concat sequence

* Update src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp
This commit is contained in:
Maxim Vafin 2023-11-28 10:34:27 +01:00 committed by GitHub
parent 685ac0d0a1
commit f90a4b9d31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 126 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <vector>
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API ConcatFusion;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief ConcatFusion transformation fuses sequence of Concats
*/
class ov::pass::ConcatFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConcatFusion", "0");
ConcatFusion();
};

View File

@ -14,6 +14,7 @@
#include "transformations/common_optimizations/binarize_weights.hpp"
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/common_optimizations/concat_fusion.hpp"
#include "transformations/common_optimizations/concat_reduce_fusion.hpp"
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
@ -219,6 +220,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
REGISTER_PASS(manager, ConvertPad12ToPad1)
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
REGISTER_PASS(manager, ConcatFusion)
auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)

View File

@ -0,0 +1,60 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/concat_fusion.hpp"
#include <memory>
#include <vector>
#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov::op;
ov::pass::ConcatFusion::ConcatFusion() {
MATCHER_SCOPE(ConcatFusion);
auto has_same_axis_concat_input = [](const Output<Node>& output) {
const auto& concat = std::dynamic_pointer_cast<v0::Concat>(output.get_node_shared_ptr());
const auto axis = concat->get_axis();
auto is_aplicable = false;
for (auto input : concat->input_values()) {
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
if (inp_concat && inp_concat->get_axis() == axis) {
is_aplicable = true;
}
}
return is_aplicable;
};
auto concat_pattern = pattern::wrap_type<v0::Concat>(has_same_axis_concat_input);
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_map();
const auto& concat = std::dynamic_pointer_cast<v0::Concat>(pattern_map.at(concat_pattern));
const auto axis = concat->get_axis();
OutputVector new_inputs;
for (auto input : concat->input_values()) {
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
if (inp_concat && inp_concat->get_axis() == axis) {
const auto inp_concat_inps = inp_concat->input_values();
new_inputs.insert(new_inputs.end(), inp_concat_inps.begin(), inp_concat_inps.end());
} else {
new_inputs.push_back(input);
}
}
auto new_concat = std::make_shared<v0::Concat>(new_inputs, axis);
replace_node(concat, new_concat);
new_concat->set_friendly_name(concat->get_friendly_name());
copy_runtime_info(concat, new_concat);
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_pattern, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -0,0 +1,34 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/concat_fusion.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset13.hpp"
using namespace testing;
using namespace ov;
TEST_F(TransformationTestsF, ConcatFusedToConcat) {
{
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
auto data2 = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 7, 14});
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{data2, data2}, 2);
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{concat1, concat2, data}, 1);
auto result = std::make_shared<opset13::Result>(concat3);
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
manager.register_pass<pass::ConcatFusion>();
}
{
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
auto data2 = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 7, 14});
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{data2, data2}, 2);
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{data, data, concat2, data}, 1);
auto result = std::make_shared<opset13::Result>(concat3);
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
}
}