Set squeeze output shape to scalar if 0 <= input_shape[0] <= 1 (#14293)

* Set squeeze output shape to scalar if 0 <= input_shape[0] <= 1

* add squeeze type_prop test case

* Update src/core/shape_inference/include/squeeze_shape_inference.hpp

Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>

* Update src/core/shape_inference/include/squeeze_shape_inference.hpp

Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
This commit is contained in:
mei, yang 2022-12-14 16:13:05 +08:00 committed by GitHub
parent 22d7bc70d9
commit 3c31488dfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 3 deletions

View File

@ -87,8 +87,13 @@ void shape_infer(const Squeeze* op,
std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), not_squeezable_at_axis);
}
// When arg shape has got static rank but shape is dynamic and output shape dimensions is empty
// make dynamic output.
output_shape = arg_shape.is_dynamic() && out_dims.empty() ? PartialShape::dynamic() : T(out_dims);
// make dynamic output except the case of the rank of arg shape is 1 and 0 <= arg_shape[0] <= 1.
if (arg_shape.is_dynamic() && out_dims.empty()) {
output_shape = arg_shape.rank().get_length() == 1 && arg_shape[0].get_max_length() <= 1
? T{} // Output shape is a scalar
: PartialShape::dynamic();
} else
output_shape = T(out_dims);
} else {
output_shape = PartialShape::dynamic();
}

View File

@ -60,7 +60,7 @@ protected:
std::set<int64_t> axes_to_remove;
if (axes.empty()) {
for (auto dim = p_shape.begin(); dim != p_shape.end(); ++dim) {
if (*dim == 1 || exp_shape.rank().is_dynamic()) {
if (dim->get_max_length() == 1 || exp_shape.rank().is_dynamic()) {
axes_to_remove.insert(std::distance(p_shape.begin(), dim));
}
}
@ -92,6 +92,7 @@ protected:
const auto static_partial_shapes_test_values =
Values(std::make_tuple(PartialShape{1}, std::vector<int64_t>{0}, PartialShape{}),
std::make_tuple(PartialShape{}, std::vector<int64_t>{0}, PartialShape{}),
std::make_tuple(PartialShape{1, 2}, std::vector<int64_t>{0}, PartialShape{2}),
std::make_tuple(PartialShape{1, 2}, std::vector<int64_t>{-2}, PartialShape{2}),
std::make_tuple(PartialShape{1, 2, 1}, std::vector<int64_t>{0}, PartialShape{2, 1}),
@ -105,6 +106,7 @@ const auto empty_axes_test_values =
std::vector<int64_t>{},
PartialShape{Dimension(2, 5), Dimension(3, 4), 6}),
std::make_tuple(PartialShape::dynamic(6), std::vector<int64_t>{}, PartialShape::dynamic()),
std::make_tuple(PartialShape{Dimension(0, 1)}, std::vector<int64_t>{}, PartialShape{}),
std::make_tuple(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()},
std::vector<int64_t>{},
PartialShape::dynamic()),