Add transformation to fuse Concat sequence (#21310)
* Add transformation to fuse Concat sequence * Update src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp
This commit is contained in:
parent
685ac0d0a1
commit
f90a4b9d31
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConcatFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ConcatFusion transformation fuses sequence of Concats
|
||||
*/
|
||||
|
||||
class ov::pass::ConcatFusion : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ConcatFusion", "0");
|
||||
ConcatFusion();
|
||||
};
|
@ -14,6 +14,7 @@
|
||||
#include "transformations/common_optimizations/binarize_weights.hpp"
|
||||
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
|
||||
#include "transformations/common_optimizations/clamp_fusion.hpp"
|
||||
#include "transformations/common_optimizations/concat_fusion.hpp"
|
||||
#include "transformations/common_optimizations/concat_reduce_fusion.hpp"
|
||||
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
|
||||
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
|
||||
@ -219,6 +220,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
|
||||
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
|
||||
REGISTER_PASS(manager, ConvertPad12ToPad1)
|
||||
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
|
||||
REGISTER_PASS(manager, ConcatFusion)
|
||||
|
||||
auto fq_fusions = manager.register_pass<GraphRewrite>();
|
||||
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
||||
|
@ -0,0 +1,60 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/concat_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
ov::pass::ConcatFusion::ConcatFusion() {
|
||||
MATCHER_SCOPE(ConcatFusion);
|
||||
auto has_same_axis_concat_input = [](const Output<Node>& output) {
|
||||
const auto& concat = std::dynamic_pointer_cast<v0::Concat>(output.get_node_shared_ptr());
|
||||
const auto axis = concat->get_axis();
|
||||
auto is_aplicable = false;
|
||||
for (auto input : concat->input_values()) {
|
||||
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
|
||||
if (inp_concat && inp_concat->get_axis() == axis) {
|
||||
is_aplicable = true;
|
||||
}
|
||||
}
|
||||
return is_aplicable;
|
||||
};
|
||||
auto concat_pattern = pattern::wrap_type<v0::Concat>(has_same_axis_concat_input);
|
||||
|
||||
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_map();
|
||||
|
||||
const auto& concat = std::dynamic_pointer_cast<v0::Concat>(pattern_map.at(concat_pattern));
|
||||
const auto axis = concat->get_axis();
|
||||
|
||||
OutputVector new_inputs;
|
||||
for (auto input : concat->input_values()) {
|
||||
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
|
||||
if (inp_concat && inp_concat->get_axis() == axis) {
|
||||
const auto inp_concat_inps = inp_concat->input_values();
|
||||
new_inputs.insert(new_inputs.end(), inp_concat_inps.begin(), inp_concat_inps.end());
|
||||
} else {
|
||||
new_inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
auto new_concat = std::make_shared<v0::Concat>(new_inputs, axis);
|
||||
replace_node(concat, new_concat);
|
||||
new_concat->set_friendly_name(concat->get_friendly_name());
|
||||
copy_runtime_info(concat, new_concat);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_pattern, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/concat_fusion.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
#include "openvino/opsets/opset13.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov;
|
||||
|
||||
TEST_F(TransformationTestsF, ConcatFusedToConcat) {
|
||||
{
|
||||
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
|
||||
auto data2 = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 7, 14});
|
||||
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
|
||||
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{data2, data2}, 2);
|
||||
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{concat1, concat2, data}, 1);
|
||||
auto result = std::make_shared<opset13::Result>(concat3);
|
||||
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
|
||||
manager.register_pass<pass::ConcatFusion>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
|
||||
auto data2 = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 7, 14});
|
||||
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{data2, data2}, 2);
|
||||
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{data, data, concat2, data}, 1);
|
||||
auto result = std::make_shared<opset13::Result>(concat3);
|
||||
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user