Handle Reshape in SplitSqueezeConcatFusion (#20345)
* Handle Reshape in SplitSqueezeConcatFusion (Ticket: CVS-122455)
* Move check for Squeeze/Reshape
* Add some comments
* Address review comments
* Add use_shapes flag to SplitSqueezeConcatFusion
This commit is contained in:
parent
57571d36e6
commit
48c9598892
@ -27,5 +27,5 @@ class TRANSFORMATIONS_API SplitSqueezeConcatFusion;
|
||||
class ov::pass::SplitSqueezeConcatFusion : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("SplitSqueezeConcatFusion", "0");
|
||||
SplitSqueezeConcatFusion();
|
||||
SplitSqueezeConcatFusion(bool use_shapes);
|
||||
};
|
||||
|
@ -170,7 +170,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
|
||||
// SplitSqueezeConcatFusion should work in same GraphRewrite as TransposesSinking,
|
||||
// because it replaces pattern that may contain Transposes which must be optimized before
|
||||
// the transformation and it also inserts Transpose that can be optimized by TransposeSinking
|
||||
ADD_MATCHER(transpose_sinking, SplitSqueezeConcatFusion)
|
||||
ADD_MATCHER(transpose_sinking, SplitSqueezeConcatFusion, m_use_shapes)
|
||||
|
||||
REGISTER_PASS(manager, TransposeMatMul)
|
||||
|
||||
|
@ -18,7 +18,9 @@
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
|
||||
ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
|
||||
static bool is_axis_squeezed_by_node(const std::shared_ptr<ov::Node>& squeeze_node, int64_t axis, bool use_shapes);
|
||||
|
||||
ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion(bool use_shapes) {
|
||||
MATCHER_SCOPE(SplitSqueezeConcatFusion);
|
||||
// Detect only concat, because we don't know how many inputs will go into concat
|
||||
auto concat_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Concat>();
|
||||
@ -32,66 +34,51 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
|
||||
|
||||
NodeVector nodes_to_delete{concat};
|
||||
|
||||
int64_t axis_value = 0;
|
||||
std::shared_ptr<ov::op::v1::Split> split;
|
||||
int64_t split_axis = 0;
|
||||
|
||||
const auto& concat_inputs = concat->input_values();
|
||||
if (concat_inputs.empty())
|
||||
return false;
|
||||
for (size_t i = 0; i < concat_inputs.size(); i++) {
|
||||
auto squeeze = std::dynamic_pointer_cast<ov::op::v0::Squeeze>(concat_inputs[i].get_node_shared_ptr());
|
||||
if (!squeeze)
|
||||
for (size_t i = 0; i < concat->get_input_size(); i++) {
|
||||
auto squeeze_node = concat->get_input_node_shared_ptr(i);
|
||||
if (!ov::is_type<ov::op::v0::Squeeze>(squeeze_node) && !ov::is_type<ov::op::v1::Reshape>(squeeze_node))
|
||||
return false;
|
||||
|
||||
nodes_to_delete.push_back(squeeze);
|
||||
|
||||
auto split_to_check = std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze->get_input_node_shared_ptr(0));
|
||||
auto split_to_check =
|
||||
std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze_node->get_input_node_shared_ptr(0));
|
||||
if (!split_to_check)
|
||||
return false;
|
||||
std::vector<int64_t> squeeze_axes_vec;
|
||||
if (squeeze->get_input_size() < 2) {
|
||||
const auto& shape = squeeze->get_input_partial_shape(0);
|
||||
if (shape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i].get_length() == 1)
|
||||
squeeze_axes_vec.push_back(static_cast<int64_t>(i));
|
||||
}
|
||||
|
||||
} else {
|
||||
auto squeeze_axes =
|
||||
std::dynamic_pointer_cast<ov::op::v0::Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!squeeze_axes)
|
||||
return false;
|
||||
squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
|
||||
}
|
||||
|
||||
if (squeeze_axes_vec.size() != 1)
|
||||
return false;
|
||||
|
||||
if (i == 0) {
|
||||
axis_value = squeeze_axes_vec[0];
|
||||
nodes_to_delete.push_back(split_to_check);
|
||||
split = split_to_check;
|
||||
} else if (axis_value != squeeze_axes_vec[0] || split_to_check != split) {
|
||||
auto split_axis_node =
|
||||
std::dynamic_pointer_cast<ov::op::v0::Constant>(split->get_input_node_shared_ptr(1));
|
||||
if (!split_axis_node)
|
||||
return false;
|
||||
auto axis_vec = split_axis_node->cast_vector<int64_t>();
|
||||
if (axis_vec.size() != 1)
|
||||
return false;
|
||||
split_axis = axis_vec[0];
|
||||
if (split_axis < 0) {
|
||||
auto rank = split->get_output_partial_shape(0).rank();
|
||||
if (rank.is_dynamic())
|
||||
return false;
|
||||
split_axis += rank.get_length();
|
||||
}
|
||||
} else if (split_to_check != split) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto split_output = squeeze->input_value(0);
|
||||
if (!is_axis_squeezed_by_node(squeeze_node, split_axis, use_shapes)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
nodes_to_delete.push_back(squeeze_node);
|
||||
|
||||
auto split_output = squeeze_node->input_value(0);
|
||||
if (split_output.get_target_inputs().size() != 1 || split_output.get_index() != i)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (split->get_num_splits() != concat_inputs.size())
|
||||
return false;
|
||||
|
||||
auto split_axis = std::dynamic_pointer_cast<ov::op::v0::Constant>(split->input_value(1).get_node_shared_ptr());
|
||||
if (!split_axis)
|
||||
return false;
|
||||
|
||||
auto axis_vec = split_axis->cast_vector<int64_t>();
|
||||
if (axis_vec.size() != 1 || axis_value != axis_vec[0])
|
||||
if (split->get_num_splits() != concat->get_input_size())
|
||||
return false;
|
||||
|
||||
auto input = split->input_value(0);
|
||||
@ -102,8 +89,8 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
|
||||
return false;
|
||||
std::vector<int64_t> order(rank.get_length());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
order.erase(order.begin() + axis_value);
|
||||
order.insert(order.begin() + concat_axis, axis_value);
|
||||
order.erase(order.begin() + split_axis);
|
||||
order.insert(order.begin() + concat_axis, split_axis);
|
||||
|
||||
auto transpose_order = ov::op::v0::Constant::create(element::i64, {(size_t)rank.get_length()}, order);
|
||||
auto transpose = register_new_node<ov::op::v1::Transpose>(input, transpose_order);
|
||||
@ -120,3 +107,67 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
/// Checks whether `squeeze_node` removes exactly the dimension `axis` of its first input
/// and leaves every other dimension in place.
///
/// The node is treated as a Reshape when it is an ov::op::v1::Reshape; any other node is
/// handled by the Squeeze logic (the matcher callback only passes Squeeze or Reshape nodes).
///
/// @param squeeze_node node whose input 0 / output 0 shapes are inspected
/// @param axis         non-negative index of the input dimension that must be removed
/// @param use_shapes   when false, a Reshape is never accepted, because proving that a
///                     Reshape behaves like a squeeze requires comparing concrete shapes
/// @return true only if the node is proven to drop `axis` and nothing else; false whenever
///         this cannot be established (dynamic rank/shape, non-constant axes, ...)
bool is_axis_squeezed_by_node(const std::shared_ptr<ov::Node>& squeeze_node, int64_t axis, bool use_shapes) {
    const auto& input_shape = squeeze_node->get_input_partial_shape(0);
    const auto& output_shape = squeeze_node->get_output_partial_shape(0);
    if (input_shape.rank().is_dynamic() || output_shape.rank().is_dynamic())
        return false;

    const auto input_rank = input_shape.rank().get_length();
    const auto output_rank = output_shape.rank().get_length();
    // check if output_rank == input_rank - 1
    // to make sure the node actually squeezes a dimension
    if (input_rank != output_rank + 1)
        return false;

    // reject an axis outside [0, input_rank) so that the indexing and the iterator
    // arithmetic below always stay inside the input shape
    if (axis < 0 || axis >= input_rank)
        return false;

    // check if squeezed dimension equals to 1
    if (input_shape[axis].is_dynamic() || input_shape[axis] != 1)
        return false;

    if (ov::is_type<ov::op::v1::Reshape>(squeeze_node)) {
        if (!use_shapes)
            return false;

        // A Reshape only counts as a squeeze of `axis` when both shapes are fully static and
        // input_shape[:axis] == output_shape[:axis] and input_shape[(axis + 1):] == output_shape[axis:]
        if (input_shape.is_dynamic() || output_shape.is_dynamic())
            return false;

        if (!std::equal(input_shape.begin(), input_shape.begin() + axis, output_shape.begin()))
            return false;

        if (!std::equal(input_shape.begin() + axis + 1, input_shape.end(), output_shape.begin() + axis))
            return false;
    } else if (squeeze_node->get_input_size() == 1) {
        // Squeeze without an axes input removes every dimension equal to 1, so the shape
        // has to be static and `axis` must be the only such dimension.
        if (input_shape.is_dynamic())
            return false;

        // input_shape[axis] == 1 is already known; any other unit dimension would be
        // squeezed as well, which disqualifies the node.
        for (size_t i = 0; i < input_shape.size(); i++) {
            if (input_shape[i].get_length() == 1 && static_cast<int64_t>(i) != axis)
                return false;
        }
    } else {
        // The second Squeeze input has explicit axes: it must be a Constant holding a single value
        auto constant = ov::as_type_ptr<ov::op::v0::Constant>(squeeze_node->get_input_node_shared_ptr(1));
        if (!constant)
            return false;
        if (ov::shape_size(constant->get_shape()) != 1)
            return false;

        // normalize a negative axis before comparing it with `axis`
        auto squeezed_axis = constant->cast_vector<int64_t>()[0];
        squeezed_axis = squeezed_axis < 0 ? squeezed_axis + input_rank : squeezed_axis;
        if (axis != squeezed_axis)
            return false;
    }

    return true;
}
|
||||
|
@ -37,7 +37,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusion) {
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
|
||||
}
|
||||
|
||||
{
|
||||
@ -69,7 +69,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionSqueezeWithoutAxesInput) {
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
|
||||
}
|
||||
|
||||
{
|
||||
@ -103,7 +103,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseNotAllSplitOutp
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
|
||||
}
|
||||
|
||||
{
|
||||
@ -144,7 +144,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitOutputsGoI
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
|
||||
}
|
||||
|
||||
{
|
||||
@ -185,7 +185,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitAxisDiffer
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
|
||||
}
|
||||
|
||||
{
|
||||
@ -222,6 +222,72 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeSqueezeWithoutAxesI
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
|
||||
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
|
||||
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
|
||||
}
|
||||
}
|
||||
|
||||
struct SplitReshapeConcatFusionParam {
|
||||
int num_splits;
|
||||
int split_axis;
|
||||
Shape input_shape;
|
||||
std::vector<int> reshaped_shape;
|
||||
int concat_axis;
|
||||
std::vector<int> transpose_order;
|
||||
bool can_fuse;
|
||||
};
|
||||
|
||||
class SplitReshapeConcatFusion : public TransformationTestsF,
|
||||
public testing::WithParamInterface<SplitReshapeConcatFusionParam> {};
|
||||
|
||||
// Builds Split -> Reshape (one per Split output) -> Concat from the test parameters and
// registers SplitSqueezeConcatFusion with use_shapes enabled. The reference model is either
// Transpose -> Reshape (fusion expected) or a clone of the original model (no fusion expected).
TEST_P(SplitReshapeConcatFusion, SplitSqueezeConcatFusion) {
    const auto& p = GetParam();
    // the split dimension must be evenly divisible, otherwise Split itself is invalid
    ASSERT_EQ(0, p.input_shape[p.split_axis] % p.num_splits);

    {
        auto data = std::make_shared<opset7::Parameter>(element::f32, p.input_shape);
        auto axis_const = opset7::Constant::create(element::i64, Shape{}, {p.split_axis});
        auto split = std::make_shared<opset7::Split>(data, axis_const, p.num_splits);

        OutputVector reshaped_outputs;
        reshaped_outputs.reserve(p.num_splits);
        // every Split output goes through a Reshape sharing the same target-shape constant
        auto target_shape_const =
            opset7::Constant::create(element::i32, Shape{p.reshaped_shape.size()}, p.reshaped_shape);
        for (int idx = 0; idx < p.num_splits; idx++) {
            reshaped_outputs.push_back(
                std::make_shared<opset7::Reshape>(split->output(idx), target_shape_const, true));
        }

        auto concat = std::make_shared<opset7::Concat>(reshaped_outputs, p.concat_axis);
        model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{data});
        manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(true);
    }

    if (p.can_fuse) {
        auto data = std::make_shared<opset7::Parameter>(element::f32, p.input_shape);
        auto order_const =
            opset7::Constant::create(element::i64, Shape{p.transpose_order.size()}, p.transpose_order);
        auto transpose = std::make_shared<opset7::Transpose>(data, order_const);

        // expected output shape: the input shape without the split axis,
        // with the concat dimension multiplied by the number of splits
        auto fused_shape = p.input_shape;
        fused_shape.erase(fused_shape.begin() + p.split_axis);
        fused_shape[p.concat_axis] *= p.num_splits;
        auto fused_shape_const = opset7::Constant::create(element::i64, Shape{fused_shape.size()}, fused_shape);
        auto reshape = std::make_shared<opset7::Reshape>(transpose, fused_shape_const, false);

        model_ref = std::make_shared<ov::Model>(NodeVector{reshape}, ParameterVector{data});
    } else {
        // negative case: the model must stay untouched
        model_ref = model->clone();
    }

    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
|
||||
|
||||
static std::vector<SplitReshapeConcatFusionParam> split_reshape_concat_fusion_params{
|
||||
{4, 2, Shape{3, 1, 4, 1, 5}, {3, 1, 1, 5}, 1, {0, 2, 1, 3, 4}, true},
|
||||
{4, 0, Shape{4, 6, 5}, {6, 5}, 1, {1, 0, 2}, true},
|
||||
{5, 2, Shape{4, 6, 5}, {4, 6}, 0, {2, 0, 1}, true},
|
||||
{2, 2, Shape{3, 1, 4, 5}, {3, 2, 5}, 1, {0, 2, 1, 3}, false},
|
||||
{2, 1, Shape{3, 2, 3, 4, 5}, {3, 3, 5, 4}, 1, {0, 2, 1, 3, 4}, false},
|
||||
{4, 2, Shape{3, 1, 4, 1, 5}, {3, 1, 5}, 1, {0, 2, 1, 3, 4}, false},
|
||||
};
|
||||
|
||||
// Run the SplitReshapeConcatFusion suite once for every entry of split_reshape_concat_fusion_params.
INSTANTIATE_TEST_SUITE_P(TransformationTests,
                         SplitReshapeConcatFusion,
                         testing::ValuesIn(split_reshape_concat_fusion_params));
|
||||
|
Loading…
Reference in New Issue
Block a user