[ShapeInference] GatherND improvements review (#15416)

* GatherND shape infer base refactor

* Use accumulate for dims fusing

* Replace accumulate with for_each
Katarzyna Mitrus 2023-02-02 11:44:13 +01:00 committed by GitHub
parent 15fa07eb5a
commit 5388a6f2af

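For context, the refactored base shape inference composes the output shape from three pieces: the batch dimensions (merged between data and indices), the middle dimensions taken from the indices shape (everything after the batch dimensions except the last dimension, whose value is the tuple length), and the tail dimensions taken from the data shape (everything after the batch and tuple dimensions). A worked example, assuming static shapes: with data shape {2, 3, 4, 5}, indices shape {2, 3, 7, 1} (tuple length 1) and batch_dims = 2, slice_length = 4 - 1 - 2 = 1 and output_indices_length = 4 - 2 - 1 = 1, so the base output shape is {2, 3, 7, 5}; the GatherND-5 post-processing in the second hunk then fuses the batch dimensions into {6, 7, 5}.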

@@ -37,29 +37,31 @@ std::vector<TShape> gather_nd_base_shape_infer(const TOp* op, const std::vector<
         NODE_VALIDATION_CHECK(
             op,
-            static_cast<int64_t>(indices_tuple_length + op->get_batch_dims()) <= data_pshape.rank().get_length(),
+            cmp::le(indices_tuple_length + op->get_batch_dims(), data_pshape.rank().get_length()),
             "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions.");
-        int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - batch_dims;
-        int64_t output_indices_length = indices_pshape.rank().get_length() - batch_dims - 1;
-        auto output_rank = output_indices_length + slice_length;
+        const auto slice_length = data_pshape.size() - indices_tuple_length - batch_dims;
+        const auto output_indices_length = indices_pshape.size() - batch_dims - 1;
         using DimType = typename TShape::value_type;
-        std::vector<DimType> output_shape(output_rank + batch_dims);
-        for (size_t dim = 0; dim < batch_dims; ++dim) {
+        std::vector<DimType> output_dims(batch_dims);
+        output_dims.reserve(batch_dims + output_indices_length + slice_length);
+        // Merge batch dimensions
+        for (size_t dim_idx = 0; dim_idx < batch_dims; ++dim_idx) {
             NODE_VALIDATION_CHECK(op,
-                                  DimType::merge(output_shape[dim], data_pshape[dim], indices_pshape[dim]),
+                                  DimType::merge(output_dims[dim_idx], data_pshape[dim_idx], indices_pshape[dim_idx]),
                                   "Batch dimensions of data and indices must be the same.");
         }
-        for (int64_t dim = 0; dim < output_indices_length; ++dim) {
-            output_shape[batch_dims + dim] = indices_pshape[batch_dims + dim];
+        // Insert middle dimensions from the indices shape
+        for (auto dim_idx = batch_dims; dim_idx < indices_pshape.size() - 1; ++dim_idx) {
+            output_dims.emplace_back(indices_pshape[dim_idx]);
         }
-        for (int64_t dim = 0; dim < slice_length; ++dim) {
-            output_shape[batch_dims + output_indices_length + dim] =
-                data_pshape[batch_dims + indices_tuple_length + dim];
+        // Insert dimensions fully taken from the data shape
+        for (auto dim_idx = batch_dims + indices_tuple_length; dim_idx < data_pshape.size(); ++dim_idx) {
+            output_dims.emplace_back(data_pshape[dim_idx]);
         }
-        return std::vector<TShape>{TShape(output_shape)};
+        return {TShape(std::move(output_dims))};
     } else {
-        return std::vector<TShape>{ov::PartialShape::dynamic()};
+        return {ov::PartialShape::dynamic()};
     }
 }
 } // namespace gather_nd
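
As an illustration of the composition above, here is a minimal standalone sketch on plain std::size_t dims, assuming fully static shapes; the function name is hypothetical and the DimType::merge validation of batch dims from the real code is reduced to copying the data batch dims:

#include <cstddef>
#include <vector>

// Illustrative sketch only, not the OpenVINO API: composes the GatherND
// output shape as batch dims + middle indices dims + data tail dims.
std::vector<std::size_t> gather_nd_output_shape(const std::vector<std::size_t>& data_shape,
                                                const std::vector<std::size_t>& indices_shape,
                                                std::size_t batch_dims) {
    const std::size_t indices_tuple_length = indices_shape.back();
    // Batch dims copied from data; the real code merges them with the
    // indices batch dims and validates that they match.
    std::vector<std::size_t> out(data_shape.begin(), data_shape.begin() + batch_dims);
    out.reserve(indices_shape.size() - 1 + data_shape.size() - indices_tuple_length - batch_dims);
    // Middle dims: indices shape without batch dims and the last (tuple) dim.
    out.insert(out.end(), indices_shape.begin() + batch_dims, indices_shape.end() - 1);
    // Tail dims: the slice taken fully from the data shape.
    out.insert(out.end(), data_shape.begin() + batch_dims + indices_tuple_length, data_shape.end());
    return out;
}

// e.g. data {2, 3, 4, 5}, indices {2, 6, 2}, batch_dims = 1 -> {2, 6, 5}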
@@ -72,13 +74,15 @@ void shape_infer(const GatherND* op, const std::vector<TShape>& input_shapes, st
     // If batch_dims > 1, batch dimensions need to be fused
     auto batch_dims = op->get_batch_dims();
     if (batch_dims > 1 && output_shapes[0].rank().is_static()) {
-        const auto& output_base_shape = output_shapes[0];
-        std::vector<DimType> output_shape{1};
-        for (size_t dim = 0; dim < batch_dims; ++dim) {
-            output_shape[0] *= output_base_shape[dim];
-        }
-        output_shape.insert(output_shape.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end());
-        output_shapes[0] = TShape(output_shape);
+        auto& output_base_shape = output_shapes[0];
+        auto output_dims = std::vector<DimType>{output_base_shape[0]};
+        std::for_each(output_base_shape.begin() + 1,
+                      output_base_shape.begin() + batch_dims,
+                      [&output_dims](const DimType& dim) {
+                          output_dims[0] *= dim;
+                      });
+        output_dims.insert(output_dims.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end());
+        output_shapes[0] = TShape(std::move(output_dims));
     }
 }
 } // namespace v5
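
The last commit message entry replaces std::accumulate with std::for_each for this fusing step: the product is seeded with the first batch dimension and the remaining batch dims are folded into it with the in-place *= that DimType supports, instead of starting from a literal 1 as the old loop did. A minimal standalone sketch of that step, again on plain std::size_t dims and with a hypothetical function name:

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative sketch only: fuses the leading batch_dims dims of a base
// shape into a single dim, then appends the remaining dims unchanged.
std::vector<std::size_t> fuse_batch_dims(const std::vector<std::size_t>& base_shape,
                                         std::size_t batch_dims) {
    std::vector<std::size_t> out{base_shape[0]};  // seed with the first batch dim
    std::for_each(base_shape.begin() + 1,
                  base_shape.begin() + batch_dims,
                  [&out](std::size_t dim) { out[0] *= dim; });
    out.insert(out.end(), base_shape.begin() + batch_dims, base_shape.end());
    return out;
}

// e.g. base {2, 3, 7, 5} with batch_dims = 2 -> {6, 7, 5}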