From 3c31488dfe96ddc06baa9b988577f6e112d03398 Mon Sep 17 00:00:00 2001 From: "mei, yang" Date: Wed, 14 Dec 2022 16:13:05 +0800 Subject: [PATCH] Set squeeze output shape to scalar if 0 <= input_shape[0] <= 1 (#14293) * Set squeeze output shape to scalar if 0 <= input_shape[0] <= 1 * add squeeze type_prop test case * Update src/core/shape_inference/include/squeeze_shape_inference.hpp Co-authored-by: Katarzyna Mitrus * Update src/core/shape_inference/include/squeeze_shape_inference.hpp Co-authored-by: Katarzyna Mitrus --- .../shape_inference/include/squeeze_shape_inference.hpp | 9 +++++++-- src/core/tests/type_prop/squeeze.cpp | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/core/shape_inference/include/squeeze_shape_inference.hpp b/src/core/shape_inference/include/squeeze_shape_inference.hpp index b240630f81d..27a1737659d 100644 --- a/src/core/shape_inference/include/squeeze_shape_inference.hpp +++ b/src/core/shape_inference/include/squeeze_shape_inference.hpp @@ -87,8 +87,13 @@ void shape_infer(const Squeeze* op, std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), not_squeezable_at_axis); } // When arg shape has got static rank but shape is dynamic and output shape dimensions is empty - // make dynamic output. - output_shape = arg_shape.is_dynamic() && out_dims.empty() ? PartialShape::dynamic() : T(out_dims); + // make dynamic output except the case of the rank of arg shape is 1 and 0 <= arg_shape[0] <= 1. + if (arg_shape.is_dynamic() && out_dims.empty()) { + output_shape = arg_shape.rank().get_length() == 1 && arg_shape[0].get_max_length() <= 1 + ? T{} // Output shape is a scalar + : PartialShape::dynamic(); + } else + output_shape = T(out_dims); } else { output_shape = PartialShape::dynamic(); } diff --git a/src/core/tests/type_prop/squeeze.cpp b/src/core/tests/type_prop/squeeze.cpp index dfb93cfb853..c2b3087fa27 100644 --- a/src/core/tests/type_prop/squeeze.cpp +++ b/src/core/tests/type_prop/squeeze.cpp @@ -60,7 +60,7 @@ protected: std::set axes_to_remove; if (axes.empty()) { for (auto dim = p_shape.begin(); dim != p_shape.end(); ++dim) { - if (*dim == 1 || exp_shape.rank().is_dynamic()) { + if (dim->get_max_length() == 1 || exp_shape.rank().is_dynamic()) { axes_to_remove.insert(std::distance(p_shape.begin(), dim)); } } @@ -92,6 +92,7 @@ protected: const auto static_partial_shapes_test_values = Values(std::make_tuple(PartialShape{1}, std::vector{0}, PartialShape{}), + std::make_tuple(PartialShape{}, std::vector{0}, PartialShape{}), std::make_tuple(PartialShape{1, 2}, std::vector{0}, PartialShape{2}), std::make_tuple(PartialShape{1, 2}, std::vector{-2}, PartialShape{2}), std::make_tuple(PartialShape{1, 2, 1}, std::vector{0}, PartialShape{2, 1}), @@ -105,6 +106,7 @@ const auto empty_axes_test_values = std::vector{}, PartialShape{Dimension(2, 5), Dimension(3, 4), 6}), std::make_tuple(PartialShape::dynamic(6), std::vector{}, PartialShape::dynamic()), + std::make_tuple(PartialShape{Dimension(0, 1)}, std::vector{}, PartialShape{}), std::make_tuple(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()}, std::vector{}, PartialShape::dynamic()),