diff --git a/src/core/include/openvino/core/partial_shape.hpp b/src/core/include/openvino/core/partial_shape.hpp index 78574e5d2e7..9a70b72867d 100644 --- a/src/core/include/openvino/core/partial_shape.hpp +++ b/src/core/include/openvino/core/partial_shape.hpp @@ -32,6 +32,7 @@ class OPENVINO_API PartialShape { using Dimensions = std::vector; public: + using value_type = Dimensions::value_type; using iterator = Dimensions::iterator; using const_iterator = Dimensions::const_iterator; using reverse_iterator = Dimensions::reverse_iterator; @@ -342,6 +343,13 @@ public: m_rank_is_static = true; m_shape_type = ShapeType::SHAPE_IS_UPDATED; } + /// \brief emplace element to the end of partial shape + template + void emplace_back(Args&&... args) { + m_dimensions.emplace_back(std::forward(args)...); + m_rank_is_static = true; + m_shape_type = ShapeType::SHAPE_IS_UPDATED; + } /// \brief String representation of PartialShape std::string to_string() const; diff --git a/src/core/shape_inference/include/squeeze_shape_inference.hpp b/src/core/shape_inference/include/squeeze_shape_inference.hpp index 27a1737659d..c33ebf74677 100644 --- a/src/core/shape_inference/include/squeeze_shape_inference.hpp +++ b/src/core/shape_inference/include/squeeze_shape_inference.hpp @@ -32,6 +32,7 @@ void shape_infer(const Squeeze* op, const auto number_of_inputs = input_shapes.size(); const auto& arg_shape = input_shapes[0]; + const auto& arg_rank = arg_shape.rank(); auto& output_shape = output_shapes[0]; std::unique_ptr> unique_axes; @@ -46,9 +47,8 @@ void shape_infer(const Squeeze* op, axes_shape.rank().get_length()); std::vector axes; - if (arg_shape.rank().is_static() && axes_shape.is_static() && - get_data_as_int64(1, op, axes, constant_data)) { - normalize_axes(op, arg_shape.rank().get_length(), axes); + if (arg_rank.is_static() && axes_shape.is_static() && get_data_as_int64(1, op, axes, constant_data)) { + normalize_axes(op, arg_rank.get_length(), axes); unique_axes.reset(new std::set(axes.cbegin(), axes.cend())); } } else { @@ -56,16 +56,18 @@ void shape_infer(const Squeeze* op, NODE_VALIDATION_CHECK(op, false); } - if (arg_shape.rank().is_static() && (unique_axes != nullptr)) { - std::vector out_dims; - out_dims.reserve(arg_shape.rank().get_length()); + if (arg_rank.is_static() && (unique_axes != nullptr)) { + output_shape.resize(0); if (unique_axes->empty()) { // According to specification, if only first input provided` or axes are empty // remove all dimensions equal to 1. - std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), [](const DimType& dim) { - return !dim.compatible(1); - }); + std::copy_if(arg_shape.cbegin(), + arg_shape.cend(), + std::back_inserter(output_shape), + [](const DimType& dim) { + return !dim.compatible(1); + }); } else { int64_t idx = 0; auto rm_axis_iter = unique_axes->cbegin(); @@ -84,16 +86,17 @@ void shape_infer(const Squeeze* op, } }; - std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), not_squeezable_at_axis); + std::copy_if(arg_shape.cbegin(), + arg_shape.cend(), + std::back_inserter(output_shape), + not_squeezable_at_axis); + } + // When arg shape has got static rank but shape is dynamic and output shape dimensions is empty (scalar) + // make dynamic output except the case when arg_shape is 1-D shape with 0 or 1 element then should be scalar. + if (arg_shape.is_dynamic() && (output_shape.size() == 0) && + !(arg_rank.get_length() == 1 && arg_shape[0].get_max_length() <= 1)) { + output_shape = PartialShape::dynamic(); } - // When arg shape has got static rank but shape is dynamic and output shape dimensions is empty - // make dynamic output except the case of the rank of arg shape is 1 and 0 <= arg_shape[0] <= 1. - if (arg_shape.is_dynamic() && out_dims.empty()) { - output_shape = arg_shape.rank().get_length() == 1 && arg_shape[0].get_max_length() <= 1 - ? T{} // Output shape is a scalar - : PartialShape::dynamic(); - } else - output_shape = T(out_dims); } else { output_shape = PartialShape::dynamic(); } diff --git a/src/core/tests/partial_shape.cpp b/src/core/tests/partial_shape.cpp index 5e15fe230c1..4689e833217 100644 --- a/src/core/tests/partial_shape.cpp +++ b/src/core/tests/partial_shape.cpp @@ -782,6 +782,22 @@ TEST(partial_shape, changed_dimension_by_reference) { ASSERT_TRUE(s.is_static()); } +TEST(partial_shape, emplace_back_new_dimension) { + PartialShape s{2, 3, Dimension::dynamic(), 5}; + + s.emplace_back(3, 5); + + ASSERT_EQ(s, PartialShape({2, 3, -1, 5, {3, 5}})); +} + +TEST(partial_shape, copy_with_back_inserter_iterator) { + PartialShape s{2, 3, Dimension::dynamic(), 5}, s_copy; + + std::copy(s.begin(), s.end(), std::back_inserter(s_copy)); + + ASSERT_EQ(s_copy, s); +} + TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_ok) { auto node = std::make_shared(element::f32, Shape{}); PartialShape data_shape{PartialShape::dynamic()};