From 441adcc1225ae751e8c553e69a793e5460c0b049 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Fri, 1 Sep 2023 21:35:17 +0200 Subject: [PATCH] SplitSqueezeConcatFusion - handle Squeeze nodes without second input (#19512) Ticket: CVS-119330 --- .../split_squeeze_concat_fusion.cpp | 26 ++++++--- .../split_squeeze_concat_fusion_test.cpp | 53 +++++++++++++++++++ 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp index 358227c39ae..b4bc2567e77 100644 --- a/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/split_squeeze_concat_fusion.cpp @@ -45,14 +45,28 @@ ov::pass::SplitSqueezeConcatFusion::SplitSqueezeConcatFusion() { nodes_to_delete.push_back(squeeze); - auto split_to_check = - std::dynamic_pointer_cast(squeeze->input_value(0).get_node_shared_ptr()); - auto squeeze_axes = - std::dynamic_pointer_cast(squeeze->input_value(1).get_node_shared_ptr()); - if (!squeeze_axes || !split_to_check) + auto split_to_check = std::dynamic_pointer_cast(squeeze->get_input_node_shared_ptr(0)); + if (!split_to_check) return false; + std::vector squeeze_axes_vec; + if (squeeze->get_input_size() < 2) { + const auto& shape = squeeze->get_input_partial_shape(0); + if (shape.is_dynamic()) { + return false; + } + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i].get_length() == 1) + squeeze_axes_vec.push_back(static_cast(i)); + } + + } else { + auto squeeze_axes = + std::dynamic_pointer_cast(squeeze->get_input_node_shared_ptr(1)); + if (!squeeze_axes) + return false; + squeeze_axes_vec = squeeze_axes->cast_vector(); + } - auto squeeze_axes_vec = squeeze_axes->cast_vector(); if (squeeze_axes_vec.size() != 1) return false; diff --git a/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp index d5e23376c66..8d72aba9b0c 100644 --- a/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp +++ b/src/common/transformations/tests/common_optimizations/split_squeeze_concat_fusion_test.cpp @@ -52,6 +52,38 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusion) { } } +TEST_F(TransformationTestsF, SplitSqueezeConcatFusionSqueezeWithoutAxesInput) { + size_t num_splits = 4; + + { + auto input = std::make_shared(element::f32, Shape{3, 2, num_splits, 640, 20, 2}); + auto split_axis = opset7::Constant::create(element::i64, Shape{}, {2}); + auto split = std::make_shared(input, split_axis, num_splits); + OutputVector squeeze_vec(num_splits); + + for (size_t i = 0; i < squeeze_vec.size(); i++) { + squeeze_vec[i] = std::make_shared(split->output(i)); + } + + auto concat = std::make_shared(squeeze_vec, 4); + + model = std::make_shared(NodeVector{concat}, ParameterVector{input}); + + manager.register_pass(); + } + + { + auto input = std::make_shared(element::f32, Shape{3, 2, num_splits, 640, 20, 2}); + auto transpose_order = opset7::Constant::create(element::i64, Shape{6}, {0, 1, 3, 4, 2, 5}); + auto transpose = std::make_shared(input, transpose_order); + auto reshape_shape = + opset7::Constant::create(element::i64, Shape{5}, {3, 2, 640, 20, 2 * (int64_t)num_splits}); + auto reshape = std::make_shared(transpose, reshape_shape, false); + + model_ref = std::make_shared(NodeVector{reshape}, ParameterVector{input}); + } +} + TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseNotAllSplitOutputsGoToSqueeze) { size_t num_splits = 4; @@ -172,3 +204,24 @@ TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeCaseSplitAxisDiffer model_ref = std::make_shared(NodeVector{concat}, ParameterVector{input}); } } + +TEST_F(TransformationTestsF, SplitSqueezeConcatFusionNegativeSqueezeWithoutAxesInputMultipleUnitDimensions) { + size_t num_splits = 4; + + { + auto input = std::make_shared(element::f32, Shape{1, 2, num_splits, 640, 20, 2}); + auto split_axis = opset7::Constant::create(element::i64, Shape{}, {2}); + auto split = std::make_shared(input, split_axis, num_splits); + OutputVector squeeze_vec(num_splits); + + for (size_t i = 0; i < squeeze_vec.size(); i++) { + squeeze_vec[i] = std::make_shared(split->output(i)); + } + + auto concat = std::make_shared(squeeze_vec, 3); + + model = std::make_shared(NodeVector{concat}, ParameterVector{input}); + + manager.register_pass(); + } +}