Handle Reshape in SplitSqueezeConcatFusion (#20345)

* Handle Reshape in SplitSqueezeConcatFusion

Ticket: CVS-122455

* move check for squeeze/reshape

* add some comments

* review comments

* add use_shapes flag to SplitSqueezeConcatFusion
This commit is contained in:
Mateusz Tabaka 2023-10-31 14:05:21 +01:00 committed by GitHub
parent 57571d36e6
commit 48c9598892
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 172 additions and 55 deletions

View File

@ -27,5 +27,5 @@ class TRANSFORMATIONS_API SplitSqueezeConcatFusion;
class ov::pass::SplitSqueezeConcatFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SplitSqueezeConcatFusion", "0");
SplitSqueezeConcatFusion();
SplitSqueezeConcatFusion(bool use_shapes);
};

View File

@ -170,7 +170,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
// SplitSqueezeConcatFusion should work in same GraphRewrite as TransposesSinking,
// because it replaces pattern that may contain Transposes which must be optimized before
// the transformation and it also inserts Transpose that can be optimized by TransposeSinking
ADD_MATCHER(transpose_sinking, SplitSqueezeConcatFusion)
ADD_MATCHER(transpose_sinking, SplitSqueezeConcatFusion, m_use_shapes)
REGISTER_PASS(manager, TransposeMatMul)

View File

@ -18,7 +18,9 @@
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
static bool is_axis_squeezed_by_node(const std::shared_ptr<ov::Node>& squeeze_node, int64_t axis, bool use_shapes);
ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion(bool use_shapes) {
MATCHER_SCOPE(SplitSqueezeConcatFusion);
// Detect only concat, because we don't know how many inputs will go into concat
auto concat_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Concat>();
@ -32,66 +34,51 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
NodeVector nodes_to_delete{concat};
int64_t axis_value = 0;
std::shared_ptr<ov::op::v1::Split> split;
int64_t split_axis = 0;
const auto& concat_inputs = concat->input_values();
if (concat_inputs.empty())
return false;
for (size_t i = 0; i < concat_inputs.size(); i++) {
auto squeeze = std::dynamic_pointer_cast<ov::op::v0::Squeeze>(concat_inputs[i].get_node_shared_ptr());
if (!squeeze)
for (size_t i = 0; i < concat->get_input_size(); i++) {
auto squeeze_node = concat->get_input_node_shared_ptr(i);
if (!ov::is_type<ov::op::v0::Squeeze>(squeeze_node) && !ov::is_type<ov::op::v1::Reshape>(squeeze_node))
return false;
nodes_to_delete.push_back(squeeze);
auto split_to_check = std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze->get_input_node_shared_ptr(0));
auto split_to_check =
std::dynamic_pointer_cast<ov::op::v1::Split>(squeeze_node->get_input_node_shared_ptr(0));
if (!split_to_check)
return false;
std::vector<int64_t> squeeze_axes_vec;
if (squeeze->get_input_size() < 2) {
const auto& shape = squeeze->get_input_partial_shape(0);
if (shape.is_dynamic()) {
return false;
}
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i].get_length() == 1)
squeeze_axes_vec.push_back(static_cast<int64_t>(i));
}
} else {
auto squeeze_axes =
std::dynamic_pointer_cast<ov::op::v0::Constant>(squeeze->get_input_node_shared_ptr(1));
if (!squeeze_axes)
return false;
squeeze_axes_vec = squeeze_axes->cast_vector<int64_t>();
}
if (squeeze_axes_vec.size() != 1)
return false;
if (i == 0) {
axis_value = squeeze_axes_vec[0];
nodes_to_delete.push_back(split_to_check);
split = split_to_check;
} else if (axis_value != squeeze_axes_vec[0] || split_to_check != split) {
auto split_axis_node =
std::dynamic_pointer_cast<ov::op::v0::Constant>(split->get_input_node_shared_ptr(1));
if (!split_axis_node)
return false;
auto axis_vec = split_axis_node->cast_vector<int64_t>();
if (axis_vec.size() != 1)
return false;
split_axis = axis_vec[0];
if (split_axis < 0) {
auto rank = split->get_output_partial_shape(0).rank();
if (rank.is_dynamic())
return false;
split_axis += rank.get_length();
}
} else if (split_to_check != split) {
return false;
}
auto split_output = squeeze->input_value(0);
if (!is_axis_squeezed_by_node(squeeze_node, split_axis, use_shapes)) {
return false;
}
nodes_to_delete.push_back(squeeze_node);
auto split_output = squeeze_node->input_value(0);
if (split_output.get_target_inputs().size() != 1 || split_output.get_index() != i)
return false;
}
if (split->get_num_splits() != concat_inputs.size())
return false;
auto split_axis = std::dynamic_pointer_cast<ov::op::v0::Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis)
return false;
auto axis_vec = split_axis->cast_vector<int64_t>();
if (axis_vec.size() != 1 || axis_value != axis_vec[0])
if (split->get_num_splits() != concat->get_input_size())
return false;
auto input = split->input_value(0);
@ -102,8 +89,8 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
return false;
std::vector<int64_t> order(rank.get_length());
std::iota(order.begin(), order.end(), 0);
order.erase(order.begin() + axis_value);
order.insert(order.begin() + concat_axis, axis_value);
order.erase(order.begin() + split_axis);
order.insert(order.begin() + concat_axis, split_axis);
auto transpose_order = ov::op::v0::Constant::create(element::i64, {(size_t)rank.get_length()}, order);
auto transpose = register_new_node<ov::op::v1::Transpose>(input, transpose_order);
@ -120,3 +107,67 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() {
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_pattern, matcher_name);
register_matcher(m, callback);
}
bool is_axis_squeezed_by_node(const std::shared_ptr<ov::Node>& squeeze_node, int64_t axis, bool use_shapes) {
const auto& input_shape = squeeze_node->get_input_partial_shape(0);
const auto& output_shape = squeeze_node->get_output_partial_shape(0);
if (input_shape.rank().is_dynamic() || output_shape.rank().is_dynamic())
return false;
auto input_rank = input_shape.rank().get_length();
auto output_rank = output_shape.rank().get_length();
// check if output_rank == input_rank - 1
// to make sure the node actually squeezes a dimension
if (input_rank != output_rank + 1)
return false;
// check if squeezed dimension equals to 1
if (input_shape[axis].is_dynamic() || input_shape[axis] != 1)
return false;
if (ov::is_type<ov::op::v1::Reshape>(squeeze_node)) {
if (!use_shapes)
return false;
// clang-format off
// check if the dimensions surrounding squeezed axis match
// function returns false if input_shape[:axis] != output_shape[:axis] or input_shape[(axis + 1):] != output_shape[axis:]
// clang-format on
if (input_shape.is_dynamic() || output_shape.is_dynamic())
return false;
if (!std::equal(input_shape.begin(), input_shape.begin() + axis, output_shape.begin()))
return false;
if (!std::equal(input_shape.begin() + axis + 1, input_shape.end(), output_shape.begin() + axis))
return false;
} else {
if (squeeze_node->get_input_size() == 1) {
// The case when Squeeze has only one input so every dimension == 1 is squeezed
if (input_shape.is_dynamic())
return false;
size_t num_squeezed_axes = 0;
for (size_t i = 0; i < input_shape.size(); i++) {
if (input_shape[i].get_length() == 1) {
num_squeezed_axes++;
if (num_squeezed_axes > 1)
return false;
if (static_cast<int64_t>(i) != axis)
return false;
}
}
} else {
// The second Squeeze input has explicit axes
auto constant = ov::as_type_ptr<ov::op::v0::Constant>(squeeze_node->get_input_node_shared_ptr(1));
if (!constant)
return false;
if (ov::shape_size(constant->get_shape()) != 1)
return false;
auto squeezed_axis = constant->cast_vector<int64_t>()[0];
squeezed_axis = squeezed_axis < 0 ? squeezed_axis + input_rank : squeezed_axis;
if (axis != squeezed_axis)
return false;
}
}
return true;
}

View File

@ -37,7 +37,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusion) {
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
}
{
@ -69,7 +69,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionSqueezeWithoutAxesInput) {
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
}
{
@ -103,7 +103,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseNotAllSplitOutp
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
}
{
@ -144,7 +144,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitOutputsGoI
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
}
{
@ -185,7 +185,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitAxisDiffer
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
model_ref = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
}
{
@ -222,6 +222,72 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeSqueezeWithoutAxesI
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>();
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(false);
}
}
struct SplitReshapeConcatFusionParam {
int num_splits;
int split_axis;
Shape input_shape;
std::vector<int> reshaped_shape;
int concat_axis;
std::vector<int> transpose_order;
bool can_fuse;
};
class SplitReshapeConcatFusion : public TransformationTestsF,
public testing::WithParamInterface<SplitReshapeConcatFusionParam> {};
TEST_P(SplitReshapeConcatFusion, SplitSqueezeConcatFusion) {
auto params = GetParam();
ASSERT_EQ(0, params.input_shape[params.split_axis] % params.num_splits);
{
auto input = std::make_shared<opset7::Parameter>(element::f32, params.input_shape);
auto split_axis_node = opset7::Constant::create(element::i64, Shape{}, {params.split_axis});
auto split = std::make_shared<opset7::Split>(input, split_axis_node, params.num_splits);
OutputVector squeeze_vec;
squeeze_vec.reserve(params.num_splits);
auto reshaped_shape_node =
opset7::Constant::create(element::i32, Shape{params.reshaped_shape.size()}, params.reshaped_shape);
for (int i = 0; i < params.num_splits; i++) {
squeeze_vec.push_back(std::make_shared<opset7::Reshape>(split->output(i), reshaped_shape_node, true));
}
auto concat = std::make_shared<opset7::Concat>(squeeze_vec, params.concat_axis);
model = std::make_shared<ov::Model>(NodeVector{concat}, ParameterVector{input});
manager.register_pass<ov::pass::SplitSqueezeConcatFusion>(true);
}
if (!params.can_fuse) {
model_ref = model->clone();
} else {
auto input = std::make_shared<opset7::Parameter>(element::f32, params.input_shape);
auto transpose_order_node =
opset7::Constant::create(element::i64, Shape{params.transpose_order.size()}, params.transpose_order);
auto transpose = std::make_shared<opset7::Transpose>(input, transpose_order_node);
auto reshape_shape = params.input_shape;
reshape_shape.erase(reshape_shape.begin() + params.split_axis);
reshape_shape[params.concat_axis] *= params.num_splits;
auto reshape_shape_node = opset7::Constant::create(element::i64, Shape{reshape_shape.size()}, reshape_shape);
auto reshape = std::make_shared<opset7::Reshape>(transpose, reshape_shape_node, false);
model_ref = std::make_shared<ov::Model>(NodeVector{reshape}, ParameterVector{input});
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
static std::vector<SplitReshapeConcatFusionParam> split_reshape_concat_fusion_params{
{4, 2, Shape{3, 1, 4, 1, 5}, {3, 1, 1, 5}, 1, {0, 2, 1, 3, 4}, true},
{4, 0, Shape{4, 6, 5}, {6, 5}, 1, {1, 0, 2}, true},
{5, 2, Shape{4, 6, 5}, {4, 6}, 0, {2, 0, 1}, true},
{2, 2, Shape{3, 1, 4, 5}, {3, 2, 5}, 1, {0, 2, 1, 3}, false},
{2, 1, Shape{3, 2, 3, 4, 5}, {3, 3, 5, 4}, 1, {0, 2, 1, 3, 4}, false},
{4, 2, Shape{3, 1, 4, 1, 5}, {3, 1, 5}, 1, {0, 2, 1, 3, 4}, false},
};
INSTANTIATE_TEST_SUITE_P(TransformationTests,
SplitReshapeConcatFusion,
testing::ValuesIn(split_reshape_concat_fusion_params));