Set squeeze output shape to scalar if 0 <= input_shape[0] <= 1 (#14293)
* Set squeeze output shape to scalar if 0 <= input_shape[0] <= 1 * add squeeze type_prop test case * Update src/core/shape_inference/include/squeeze_shape_inference.hpp Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com> * Update src/core/shape_inference/include/squeeze_shape_inference.hpp Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
This commit is contained in:
parent
22d7bc70d9
commit
3c31488dfe
@ -87,8 +87,13 @@ void shape_infer(const Squeeze* op,
|
||||
std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), not_squeezable_at_axis);
|
||||
}
|
||||
// When arg shape has got static rank but shape is dynamic and output shape dimensions is empty
|
||||
// make dynamic output.
|
||||
output_shape = arg_shape.is_dynamic() && out_dims.empty() ? PartialShape::dynamic() : T(out_dims);
|
||||
// make dynamic output except the case of the rank of arg shape is 1 and 0 <= arg_shape[0] <= 1.
|
||||
if (arg_shape.is_dynamic() && out_dims.empty()) {
|
||||
output_shape = arg_shape.rank().get_length() == 1 && arg_shape[0].get_max_length() <= 1
|
||||
? T{} // Output shape is a scalar
|
||||
: PartialShape::dynamic();
|
||||
} else
|
||||
output_shape = T(out_dims);
|
||||
} else {
|
||||
output_shape = PartialShape::dynamic();
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ protected:
|
||||
std::set<int64_t> axes_to_remove;
|
||||
if (axes.empty()) {
|
||||
for (auto dim = p_shape.begin(); dim != p_shape.end(); ++dim) {
|
||||
if (*dim == 1 || exp_shape.rank().is_dynamic()) {
|
||||
if (dim->get_max_length() == 1 || exp_shape.rank().is_dynamic()) {
|
||||
axes_to_remove.insert(std::distance(p_shape.begin(), dim));
|
||||
}
|
||||
}
|
||||
@ -92,6 +92,7 @@ protected:
|
||||
|
||||
const auto static_partial_shapes_test_values =
|
||||
Values(std::make_tuple(PartialShape{1}, std::vector<int64_t>{0}, PartialShape{}),
|
||||
std::make_tuple(PartialShape{}, std::vector<int64_t>{0}, PartialShape{}),
|
||||
std::make_tuple(PartialShape{1, 2}, std::vector<int64_t>{0}, PartialShape{2}),
|
||||
std::make_tuple(PartialShape{1, 2}, std::vector<int64_t>{-2}, PartialShape{2}),
|
||||
std::make_tuple(PartialShape{1, 2, 1}, std::vector<int64_t>{0}, PartialShape{2, 1}),
|
||||
@ -105,6 +106,7 @@ const auto empty_axes_test_values =
|
||||
std::vector<int64_t>{},
|
||||
PartialShape{Dimension(2, 5), Dimension(3, 4), 6}),
|
||||
std::make_tuple(PartialShape::dynamic(6), std::vector<int64_t>{}, PartialShape::dynamic()),
|
||||
std::make_tuple(PartialShape{Dimension(0, 1)}, std::vector<int64_t>{}, PartialShape{}),
|
||||
std::make_tuple(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()},
|
||||
std::vector<int64_t>{},
|
||||
PartialShape::dynamic()),
|
||||
|
Loading…
Reference in New Issue
Block a user