From bd0c156a70693ac6f7b10fa64d10d8d6aa532478 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Mon, 4 Sep 2023 13:58:53 +0200 Subject: [PATCH] PullReshapeThroughReduce - skip transformation if Reshape doesn't unsqueeze input (#19477) Ticket: CVS-118905 --- .../common_optimizations/pull_through_reduce.cpp | 3 +++ .../tests/common_optimizations/pull_through_reduce_test.cpp | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/src/common/transformations/src/transformations/common_optimizations/pull_through_reduce.cpp b/src/common/transformations/src/transformations/common_optimizations/pull_through_reduce.cpp index 8f16025a1ac..0ceac5ae44a 100644 --- a/src/common/transformations/src/transformations/common_optimizations/pull_through_reduce.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/pull_through_reduce.cpp @@ -55,6 +55,9 @@ const std::vector adjust_axes(const std::vector& axes_to_align // - Reshape(input_shape={5,10,15}, target_shape={5,10,1,15}), 2 axis is returned std::vector try_get_unsqueeze_axes_from_reshape(const ov::Shape& target_shape, const ov::Shape& input_shape) { std::vector result; + if (target_shape.size() <= input_shape.size()) { + return result; + } if (input_shape.size() == 0) { // scalar case - can be reshaped only to [1,..,1] shape result.resize(target_shape.size(), 0); std::iota(std::begin(result), std::end(result), 0); diff --git a/src/common/transformations/tests/common_optimizations/pull_through_reduce_test.cpp b/src/common/transformations/tests/common_optimizations/pull_through_reduce_test.cpp index a516e966090..fb1689f6bbc 100644 --- a/src/common/transformations/tests/common_optimizations/pull_through_reduce_test.cpp +++ b/src/common/transformations/tests/common_optimizations/pull_through_reduce_test.cpp @@ -360,6 +360,11 @@ TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfTheSameAxesScalarCase manager.register_pass(); } +TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfReshapeDoesntUnsqueeze) { + model = generate_reshape_model(element::f32, {1, 100, 1}, {1, 1, 100}, {2}); + manager.register_pass(); +} + TEST_F(TransformationTestsF, PullReshapeThroughReduceSkipIfNonConstAxes) { const auto input = std::make_shared(element::f32, PartialShape{5, 10, 15}); const auto target_shape = Constant::create(element::i64, Shape{4}, {1, 5, 10, 15});