From 2071f728b7ab7154c676beff04661e05c265caf6 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Tue, 31 Oct 2023 14:05:21 +0100 Subject: [PATCH] Handle Reshape in SplitSqueezeConcatFusion (#20345) * Handle Reshape in SplitSqueezeConcatFusion Ticket: CVS-122455 * move check for squeeze/reshape * add some comments * review comments * add use_shapes flag to SplitSqueezeConcatFusion --- .../split_squeeze_concat_fusion.hpp | 2 +- .../moc_transformations.cpp | 2 +- .../split_squeeze_concat_fusion.cpp | 145 ++++++++++++------ .../split_squeeze_concat_fusion_test.cpp | 78 +++++++++- 4 files changed, 172 insertions(+), 55 deletions(-) diff --git a/src/common/transformations/include/transformations/common_optimizations/split_squeeze_concat_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/split_squeeze_concat_fusion.hpp index 733e6c66a5f..28f94637523 100644 --- a/src/common/transformations/include/transformations/common_optimizations/split_squeeze_concat_fusion.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/split_squeeze_concat_fusion.hpp @@ -27,5 +27,5 @@ class TRANSFORMATIONS_API SplitSqueezeConcatFusion; class ov::pass::SplitSqueezeConcatFusion : public ov::pass::MatcherPass { public: OPENVINO_RTTI("SplitSqueezeConcatFusion", "0"); - SplitSqueezeConcatFusion(); + SplitSqueezeConcatFusion(bool use_shapes); }; diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 9a3446f2386..5c768be324e 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -170,7 +170,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr // SplitSqueezeConcatFusion should work in same GraphRewrite as TransposesSinking, // because it replaces pattern that may contain Transposes which must be optimized before // the transformation and it also inserts Transpose that can be optimized by TransposeSinking - ADD_MATCHER(transpose_sinking, SplitSqueezeConcatFusion) + ADD_MATCHER(transpose_sinking, SplitSqueezeConcatFusion, m_use_shapes) REGISTER_PASS(manager, TransposeMatMul) diff --git a/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp index b4bc2567e77..0baac07d14e 100644 --- a/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp @@ -18,7 +18,9 @@ #include "openvino/op/transpose.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" -ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() { +static bool is_axis_squeezed_by_node(const std::shared_ptr& squeeze_node, int64_t axis, bool use_shapes); + +ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion(bool use_shapes) { MATCHER_SCOPE(SplitSqueezeConcatFusion); // Detect only concat, because we don't know how many inputs will go into concat auto concat_pattern = ov::pass::pattern::wrap_type(); @@ -32,66 +34,51 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() { NodeVector nodes_to_delete{concat}; - int64_t axis_value = 0; std::shared_ptr split; + int64_t split_axis = 0; - const auto& concat_inputs = concat->input_values(); - if (concat_inputs.empty()) - return false; - for (size_t i = 0; i < concat_inputs.size(); i++) { - auto squeeze = std::dynamic_pointer_cast(concat_inputs[i].get_node_shared_ptr()); - if (!squeeze) + for (size_t i = 0; i < concat->get_input_size(); i++) { + auto squeeze_node = concat->get_input_node_shared_ptr(i); + if (!ov::is_type(squeeze_node) && !ov::is_type(squeeze_node)) return false; - - nodes_to_delete.push_back(squeeze); - - auto split_to_check = std::dynamic_pointer_cast(squeeze->get_input_node_shared_ptr(0)); + auto split_to_check = + std::dynamic_pointer_cast(squeeze_node->get_input_node_shared_ptr(0)); if (!split_to_check) return false; - std::vector squeeze_axes_vec; - if (squeeze->get_input_size() < 2) { - const auto& shape = squeeze->get_input_partial_shape(0); - if (shape.is_dynamic()) { - return false; - } - for (size_t i = 0; i < shape.size(); i++) { - if (shape[i].get_length() == 1) - squeeze_axes_vec.push_back(static_cast(i)); - } - - } else { - auto squeeze_axes = - std::dynamic_pointer_cast(squeeze->get_input_node_shared_ptr(1)); - if (!squeeze_axes) - return false; - squeeze_axes_vec = squeeze_axes->cast_vector(); - } - - if (squeeze_axes_vec.size() != 1) - return false; if (i == 0) { - axis_value = squeeze_axes_vec[0]; nodes_to_delete.push_back(split_to_check); split = split_to_check; - } else if (axis_value != squeeze_axes_vec[0] || split_to_check != split) { + auto split_axis_node = + std::dynamic_pointer_cast(split->get_input_node_shared_ptr(1)); + if (!split_axis_node) + return false; + auto axis_vec = split_axis_node->cast_vector(); + if (axis_vec.size() != 1) + return false; + split_axis = axis_vec[0]; + if (split_axis < 0) { + auto rank = split->get_output_partial_shape(0).rank(); + if (rank.is_dynamic()) + return false; + split_axis += rank.get_length(); + } + } else if (split_to_check != split) { return false; } - auto split_output = squeeze->input_value(0); + if (!is_axis_squeezed_by_node(squeeze_node, split_axis, use_shapes)) { + return false; + } + + nodes_to_delete.push_back(squeeze_node); + + auto split_output = squeeze_node->input_value(0); if (split_output.get_target_inputs().size() != 1 || split_output.get_index() != i) return false; } - if (split->get_num_splits() != concat_inputs.size()) - return false; - - auto split_axis = std::dynamic_pointer_cast(split->input_value(1).get_node_shared_ptr()); - if (!split_axis) - return false; - - auto axis_vec = split_axis->cast_vector(); - if (axis_vec.size() != 1 || axis_value != axis_vec[0]) + if (split->get_num_splits() != concat->get_input_size()) return false; auto input = split->input_value(0); @@ -102,8 +89,8 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() { return false; std::vector order(rank.get_length()); std::iota(order.begin(), order.end(), 0); - order.erase(order.begin() + axis_value); - order.insert(order.begin() + concat_axis, axis_value); + order.erase(order.begin() + split_axis); + order.insert(order.begin() + concat_axis, split_axis); auto transpose_order = ov::op::v0::Constant::create(element::i64, {(size_t)rank.get_length()}, order); auto transpose = register_new_node(input, transpose_order); @@ -120,3 +107,67 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() { auto m = std::make_shared(concat_pattern, matcher_name); register_matcher(m, callback); } + +bool is_axis_squeezed_by_node(const std::shared_ptr& squeeze_node, int64_t axis, bool use_shapes) { + const auto& input_shape = squeeze_node->get_input_partial_shape(0); + const auto& output_shape = squeeze_node->get_output_partial_shape(0); + if (input_shape.rank().is_dynamic() || output_shape.rank().is_dynamic()) + return false; + + auto input_rank = input_shape.rank().get_length(); + auto output_rank = output_shape.rank().get_length(); + // check if output_rank == input_rank - 1 + // to make sure the node actually squeezes a dimension + if (input_rank != output_rank + 1) + return false; + + // check if squeezed dimension equals to 1 + if (input_shape[axis].is_dynamic() || input_shape[axis] != 1) + return false; + + if (ov::is_type(squeeze_node)) { + if (!use_shapes) + return false; + // clang-format off + // check if the dimensions surrounding squeezed axis match + // function returns false if input_shape[:axis] != output_shape[:axis] or input_shape[(axis + 1):] != output_shape[axis:] + // clang-format on + if (input_shape.is_dynamic() || output_shape.is_dynamic()) + return false; + + if (!std::equal(input_shape.begin(), input_shape.begin() + axis, output_shape.begin())) + return false; + + if (!std::equal(input_shape.begin() + axis + 1, input_shape.end(), output_shape.begin() + axis)) + return false; + } else { + if (squeeze_node->get_input_size() == 1) { + // The case when Squeeze has only one input so every dimension == 1 is squeezed + if (input_shape.is_dynamic()) + return false; + size_t num_squeezed_axes = 0; + for (size_t i = 0; i < input_shape.size(); i++) { + if (input_shape[i].get_length() == 1) { + num_squeezed_axes++; + if (num_squeezed_axes > 1) + return false; + if (static_cast(i) != axis) + return false; + } + } + } else { + // The second Squeeze input has explicit axes + auto constant = ov::as_type_ptr(squeeze_node->get_input_node_shared_ptr(1)); + if (!constant) + return false; + if (ov::shape_size(constant->get_shape()) != 1) + return false; + auto squeezed_axis = constant->cast_vector()[0]; + squeezed_axis = squeezed_axis < 0 ? squeezed_axis + input_rank : squeezed_axis; + if (axis != squeezed_axis) + return false; + } + } + + return true; +} diff --git a/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp index 8d72aba9b0c..3bea132c205 100644 --- a/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp +++ b/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp @@ -37,7 +37,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusion) { model = std::make_shared(NodeVector{concat}, ParameterVector{input}); - manager.register_pass(); + manager.register_pass(false); } { @@ -69,7 +69,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionSqueezeWithoutAxesInput) { model = std::make_shared(NodeVector{concat}, ParameterVector{input}); - manager.register_pass(); + manager.register_pass(false); } { @@ -103,7 +103,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseNotAllSplitOutp model = std::make_shared(NodeVector{concat}, ParameterVector{input}); model_ref = std::make_shared(NodeVector{concat}, ParameterVector{input}); - manager.register_pass(); + manager.register_pass(false); } { @@ -144,7 +144,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitOutputsGoI model = std::make_shared(NodeVector{concat}, ParameterVector{input}); model_ref = std::make_shared(NodeVector{concat}, ParameterVector{input}); - manager.register_pass(); + manager.register_pass(false); } { @@ -185,7 +185,7 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitAxisDiffer model = std::make_shared(NodeVector{concat}, ParameterVector{input}); model_ref = std::make_shared(NodeVector{concat}, ParameterVector{input}); - manager.register_pass(); + manager.register_pass(false); } { @@ -222,6 +222,72 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeSqueezeWithoutAxesI model = std::make_shared(NodeVector{concat}, ParameterVector{input}); - manager.register_pass(); + manager.register_pass(false); } } + +struct SplitReshapeConcatFusionParam { + int num_splits; + int split_axis; + Shape input_shape; + std::vector reshaped_shape; + int concat_axis; + std::vector transpose_order; + bool can_fuse; +}; + +class SplitReshapeConcatFusion : public TransformationTestsF, + public testing::WithParamInterface {}; + +TEST_P(SplitReshapeConcatFusion, SplitSqueezeConcatFusion) { + auto params = GetParam(); + ASSERT_EQ(0, params.input_shape[params.split_axis] % params.num_splits); + + { + auto input = std::make_shared(element::f32, params.input_shape); + auto split_axis_node = opset7::Constant::create(element::i64, Shape{}, {params.split_axis}); + auto split = std::make_shared(input, split_axis_node, params.num_splits); + OutputVector squeeze_vec; + squeeze_vec.reserve(params.num_splits); + auto reshaped_shape_node = + opset7::Constant::create(element::i32, Shape{params.reshaped_shape.size()}, params.reshaped_shape); + for (int i = 0; i < params.num_splits; i++) { + squeeze_vec.push_back(std::make_shared(split->output(i), reshaped_shape_node, true)); + } + auto concat = std::make_shared(squeeze_vec, params.concat_axis); + model = std::make_shared(NodeVector{concat}, ParameterVector{input}); + manager.register_pass(true); + } + + if (!params.can_fuse) { + model_ref = model->clone(); + } else { + auto input = std::make_shared(element::f32, params.input_shape); + auto transpose_order_node = + opset7::Constant::create(element::i64, Shape{params.transpose_order.size()}, params.transpose_order); + auto transpose = std::make_shared(input, transpose_order_node); + auto reshape_shape = params.input_shape; + reshape_shape.erase(reshape_shape.begin() + params.split_axis); + reshape_shape[params.concat_axis] *= params.num_splits; + auto reshape_shape_node = opset7::Constant::create(element::i64, Shape{reshape_shape.size()}, reshape_shape); + auto reshape = std::make_shared(transpose, reshape_shape_node, false); + + model_ref = std::make_shared(NodeVector{reshape}, ParameterVector{input}); + } + + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} + +static std::vector split_reshape_concat_fusion_params{ + {4, 2, Shape{3, 1, 4, 1, 5}, {3, 1, 1, 5}, 1, {0, 2, 1, 3, 4}, true}, + {4, 0, Shape{4, 6, 5}, {6, 5}, 1, {1, 0, 2}, true}, + {5, 2, Shape{4, 6, 5}, {4, 6}, 0, {2, 0, 1}, true}, + {2, 2, Shape{3, 1, 4, 5}, {3, 2, 5}, 1, {0, 2, 1, 3}, false}, + {2, 1, Shape{3, 2, 3, 4, 5}, {3, 3, 5, 4}, 1, {0, 2, 1, 3, 4}, false}, + {4, 2, Shape{3, 1, 4, 1, 5}, {3, 1, 5}, 1, {0, 2, 1, 3, 4}, false}, +}; + +INSTANTIATE_TEST_SUITE_P(TransformationTests, + SplitReshapeConcatFusion, + testing::ValuesIn(split_reshape_concat_fusion_params));