Extend partial shape to use back inserter and emplace (#14723)
* Extend partial shape interface add value_type - add emplace_back * Example to created dimension in output shape instead of using tmp vector * Add partial shape tests
This commit is contained in:
parent
fcd33063be
commit
5714fdfe6b
@ -32,6 +32,7 @@ class OPENVINO_API PartialShape {
|
||||
using Dimensions = std::vector<Dimension>;
|
||||
|
||||
public:
|
||||
using value_type = Dimensions::value_type;
|
||||
using iterator = Dimensions::iterator;
|
||||
using const_iterator = Dimensions::const_iterator;
|
||||
using reverse_iterator = Dimensions::reverse_iterator;
|
||||
@ -342,6 +343,13 @@ public:
|
||||
m_rank_is_static = true;
|
||||
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
|
||||
}
|
||||
/// \brief emplace element to the end of partial shape
|
||||
template <class... Args>
|
||||
void emplace_back(Args&&... args) {
|
||||
m_dimensions.emplace_back(std::forward<Args>(args)...);
|
||||
m_rank_is_static = true;
|
||||
m_shape_type = ShapeType::SHAPE_IS_UPDATED;
|
||||
}
|
||||
|
||||
/// \brief String representation of PartialShape
|
||||
std::string to_string() const;
|
||||
|
@ -32,6 +32,7 @@ void shape_infer(const Squeeze* op,
|
||||
const auto number_of_inputs = input_shapes.size();
|
||||
|
||||
const auto& arg_shape = input_shapes[0];
|
||||
const auto& arg_rank = arg_shape.rank();
|
||||
auto& output_shape = output_shapes[0];
|
||||
|
||||
std::unique_ptr<std::set<int64_t>> unique_axes;
|
||||
@ -46,9 +47,8 @@ void shape_infer(const Squeeze* op,
|
||||
axes_shape.rank().get_length());
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
if (arg_shape.rank().is_static() && axes_shape.is_static() &&
|
||||
get_data_as_int64<T>(1, op, axes, constant_data)) {
|
||||
normalize_axes(op, arg_shape.rank().get_length(), axes);
|
||||
if (arg_rank.is_static() && axes_shape.is_static() && get_data_as_int64<T>(1, op, axes, constant_data)) {
|
||||
normalize_axes(op, arg_rank.get_length(), axes);
|
||||
unique_axes.reset(new std::set<int64_t>(axes.cbegin(), axes.cend()));
|
||||
}
|
||||
} else {
|
||||
@ -56,16 +56,18 @@ void shape_infer(const Squeeze* op,
|
||||
NODE_VALIDATION_CHECK(op, false);
|
||||
}
|
||||
|
||||
if (arg_shape.rank().is_static() && (unique_axes != nullptr)) {
|
||||
std::vector<DimType> out_dims;
|
||||
out_dims.reserve(arg_shape.rank().get_length());
|
||||
if (arg_rank.is_static() && (unique_axes != nullptr)) {
|
||||
output_shape.resize(0);
|
||||
|
||||
if (unique_axes->empty()) {
|
||||
// According to specification, if only first input provided` or axes are empty
|
||||
// remove all dimensions equal to 1.
|
||||
std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), [](const DimType& dim) {
|
||||
return !dim.compatible(1);
|
||||
});
|
||||
std::copy_if(arg_shape.cbegin(),
|
||||
arg_shape.cend(),
|
||||
std::back_inserter(output_shape),
|
||||
[](const DimType& dim) {
|
||||
return !dim.compatible(1);
|
||||
});
|
||||
} else {
|
||||
int64_t idx = 0;
|
||||
auto rm_axis_iter = unique_axes->cbegin();
|
||||
@ -84,16 +86,17 @@ void shape_infer(const Squeeze* op,
|
||||
}
|
||||
};
|
||||
|
||||
std::copy_if(arg_shape.cbegin(), arg_shape.cend(), back_inserter(out_dims), not_squeezable_at_axis);
|
||||
std::copy_if(arg_shape.cbegin(),
|
||||
arg_shape.cend(),
|
||||
std::back_inserter(output_shape),
|
||||
not_squeezable_at_axis);
|
||||
}
|
||||
// When arg shape has got static rank but shape is dynamic and output shape dimensions is empty (scalar)
|
||||
// make dynamic output except the case when arg_shape is 1-D shape with 0 or 1 element then should be scalar.
|
||||
if (arg_shape.is_dynamic() && (output_shape.size() == 0) &&
|
||||
!(arg_rank.get_length() == 1 && arg_shape[0].get_max_length() <= 1)) {
|
||||
output_shape = PartialShape::dynamic();
|
||||
}
|
||||
// When arg shape has got static rank but shape is dynamic and output shape dimensions is empty
|
||||
// make dynamic output except the case of the rank of arg shape is 1 and 0 <= arg_shape[0] <= 1.
|
||||
if (arg_shape.is_dynamic() && out_dims.empty()) {
|
||||
output_shape = arg_shape.rank().get_length() == 1 && arg_shape[0].get_max_length() <= 1
|
||||
? T{} // Output shape is a scalar
|
||||
: PartialShape::dynamic();
|
||||
} else
|
||||
output_shape = T(out_dims);
|
||||
} else {
|
||||
output_shape = PartialShape::dynamic();
|
||||
}
|
||||
|
@ -782,6 +782,22 @@ TEST(partial_shape, changed_dimension_by_reference) {
|
||||
ASSERT_TRUE(s.is_static());
|
||||
}
|
||||
|
||||
TEST(partial_shape, emplace_back_new_dimension) {
|
||||
PartialShape s{2, 3, Dimension::dynamic(), 5};
|
||||
|
||||
s.emplace_back(3, 5);
|
||||
|
||||
ASSERT_EQ(s, PartialShape({2, 3, -1, 5, {3, 5}}));
|
||||
}
|
||||
|
||||
TEST(partial_shape, copy_with_back_inserter_iterator) {
|
||||
PartialShape s{2, 3, Dimension::dynamic(), 5}, s_copy;
|
||||
|
||||
std::copy(s.begin(), s.end(), std::back_inserter(s_copy));
|
||||
|
||||
ASSERT_EQ(s_copy, s);
|
||||
}
|
||||
|
||||
TEST(partial_shape, infer_windowed_reduction_rank_dynamic_rank_dynamic_ok) {
|
||||
auto node = std::make_shared<op::Parameter>(element::f32, Shape{});
|
||||
PartialShape data_shape{PartialShape::dynamic()};
|
||||
|
Loading…
Reference in New Issue
Block a user