Add transformation to fuse Concat sequence (#21310)
* Add transformation to fuse Concat sequence * Update src/common/transformations/src/transformations/common_optimizations/concat_fusion.cpp
This commit is contained in:
parent
685ac0d0a1
commit
f90a4b9d31
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API ConcatFusion;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief ConcatFusion transformation fuses sequence of Concats
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ov::pass::ConcatFusion : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ConcatFusion", "0");
|
||||||
|
ConcatFusion();
|
||||||
|
};
|
@ -14,6 +14,7 @@
|
|||||||
#include "transformations/common_optimizations/binarize_weights.hpp"
|
#include "transformations/common_optimizations/binarize_weights.hpp"
|
||||||
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
|
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
|
||||||
#include "transformations/common_optimizations/clamp_fusion.hpp"
|
#include "transformations/common_optimizations/clamp_fusion.hpp"
|
||||||
|
#include "transformations/common_optimizations/concat_fusion.hpp"
|
||||||
#include "transformations/common_optimizations/concat_reduce_fusion.hpp"
|
#include "transformations/common_optimizations/concat_reduce_fusion.hpp"
|
||||||
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
|
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
|
||||||
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
|
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
|
||||||
@ -219,6 +220,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
|
|||||||
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
|
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
|
||||||
REGISTER_PASS(manager, ConvertPad12ToPad1)
|
REGISTER_PASS(manager, ConvertPad12ToPad1)
|
||||||
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
|
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
|
||||||
|
REGISTER_PASS(manager, ConcatFusion)
|
||||||
|
|
||||||
auto fq_fusions = manager.register_pass<GraphRewrite>();
|
auto fq_fusions = manager.register_pass<GraphRewrite>();
|
||||||
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/concat_fusion.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/core/rt_info.hpp"
|
||||||
|
#include "openvino/op/concat.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "transformations/utils/utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov::op;
|
||||||
|
|
||||||
|
ov::pass::ConcatFusion::ConcatFusion() {
|
||||||
|
MATCHER_SCOPE(ConcatFusion);
|
||||||
|
auto has_same_axis_concat_input = [](const Output<Node>& output) {
|
||||||
|
const auto& concat = std::dynamic_pointer_cast<v0::Concat>(output.get_node_shared_ptr());
|
||||||
|
const auto axis = concat->get_axis();
|
||||||
|
auto is_aplicable = false;
|
||||||
|
for (auto input : concat->input_values()) {
|
||||||
|
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
|
||||||
|
if (inp_concat && inp_concat->get_axis() == axis) {
|
||||||
|
is_aplicable = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return is_aplicable;
|
||||||
|
};
|
||||||
|
auto concat_pattern = pattern::wrap_type<v0::Concat>(has_same_axis_concat_input);
|
||||||
|
|
||||||
|
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
|
const auto& pattern_map = m.get_pattern_map();
|
||||||
|
|
||||||
|
const auto& concat = std::dynamic_pointer_cast<v0::Concat>(pattern_map.at(concat_pattern));
|
||||||
|
const auto axis = concat->get_axis();
|
||||||
|
|
||||||
|
OutputVector new_inputs;
|
||||||
|
for (auto input : concat->input_values()) {
|
||||||
|
const auto inp_concat = std::dynamic_pointer_cast<v0::Concat>(input.get_node_shared_ptr());
|
||||||
|
if (inp_concat && inp_concat->get_axis() == axis) {
|
||||||
|
const auto inp_concat_inps = inp_concat->input_values();
|
||||||
|
new_inputs.insert(new_inputs.end(), inp_concat_inps.begin(), inp_concat_inps.end());
|
||||||
|
} else {
|
||||||
|
new_inputs.push_back(input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_concat = std::make_shared<v0::Concat>(new_inputs, axis);
|
||||||
|
replace_node(concat, new_concat);
|
||||||
|
new_concat->set_friendly_name(concat->get_friendly_name());
|
||||||
|
copy_runtime_info(concat, new_concat);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_pattern, matcher_name);
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/concat_fusion.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/ov_test_utils.hpp"
|
||||||
|
#include "openvino/opsets/opset13.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, ConcatFusedToConcat) {
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
|
||||||
|
auto data2 = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 7, 14});
|
||||||
|
auto concat1 = std::make_shared<opset13::Concat>(OutputVector{data, data}, 1);
|
||||||
|
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{data2, data2}, 2);
|
||||||
|
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{concat1, concat2, data}, 1);
|
||||||
|
auto result = std::make_shared<opset13::Result>(concat3);
|
||||||
|
model = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
|
||||||
|
manager.register_pass<pass::ConcatFusion>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 14, 14});
|
||||||
|
auto data2 = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1, 3, 7, 14});
|
||||||
|
auto concat2 = std::make_shared<opset13::Concat>(OutputVector{data2, data2}, 2);
|
||||||
|
auto concat3 = std::make_shared<opset13::Concat>(OutputVector{data, data, concat2, data}, 1);
|
||||||
|
auto result = std::make_shared<opset13::Result>(concat3);
|
||||||
|
model_ref = std::make_shared<Model>(ResultVector{result}, ParameterVector{data, data2});
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user