[ShapeInference] GatherND improvements review (#15416)
* GatherND shape infer base refactor * Use accumulate for dims fusing * Replace accumulate with for_each
This commit is contained in:
parent
15fa07eb5a
commit
5388a6f2af
@ -37,29 +37,31 @@ std::vector<TShape> gather_nd_base_shape_infer(const TOp* op, const std::vector<
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
static_cast<int64_t>(indices_tuple_length + op->get_batch_dims()) <= data_pshape.rank().get_length(),
|
||||
cmp::le(indices_tuple_length + op->get_batch_dims(), data_pshape.rank().get_length()),
|
||||
"Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions.");
|
||||
int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - batch_dims;
|
||||
int64_t output_indices_length = indices_pshape.rank().get_length() - batch_dims - 1;
|
||||
auto output_rank = output_indices_length + slice_length;
|
||||
const auto slice_length = data_pshape.size() - indices_tuple_length - batch_dims;
|
||||
const auto output_indices_length = indices_pshape.size() - batch_dims - 1;
|
||||
|
||||
using DimType = typename TShape::value_type;
|
||||
std::vector<DimType> output_shape(output_rank + batch_dims);
|
||||
for (size_t dim = 0; dim < batch_dims; ++dim) {
|
||||
std::vector<DimType> output_dims(batch_dims);
|
||||
output_dims.reserve(batch_dims + output_indices_length + slice_length);
|
||||
// Merge batch dimensions
|
||||
for (size_t dim_idx = 0; dim_idx < batch_dims; ++dim_idx) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(output_shape[dim], data_pshape[dim], indices_pshape[dim]),
|
||||
DimType::merge(output_dims[dim_idx], data_pshape[dim_idx], indices_pshape[dim_idx]),
|
||||
"Batch dimensions of data and indices must be the same.");
|
||||
}
|
||||
for (int64_t dim = 0; dim < output_indices_length; ++dim) {
|
||||
output_shape[batch_dims + dim] = indices_pshape[batch_dims + dim];
|
||||
// Insert middle dimensions from the indices shape
|
||||
for (auto dim_idx = batch_dims; dim_idx < indices_pshape.size() - 1; ++dim_idx) {
|
||||
output_dims.emplace_back(indices_pshape[dim_idx]);
|
||||
}
|
||||
for (int64_t dim = 0; dim < slice_length; ++dim) {
|
||||
output_shape[batch_dims + output_indices_length + dim] =
|
||||
data_pshape[batch_dims + indices_tuple_length + dim];
|
||||
// Insert dimensions fully taken from the data shape
|
||||
for (auto dim_idx = batch_dims + indices_tuple_length; dim_idx < data_pshape.size(); ++dim_idx) {
|
||||
output_dims.emplace_back(data_pshape[dim_idx]);
|
||||
}
|
||||
return std::vector<TShape>{TShape(output_shape)};
|
||||
return {TShape(std::move(output_dims))};
|
||||
} else {
|
||||
return std::vector<TShape>{ov::PartialShape::dynamic()};
|
||||
return {ov::PartialShape::dynamic()};
|
||||
}
|
||||
}
|
||||
} // namespace gather_nd
|
||||
@ -72,13 +74,15 @@ void shape_infer(const GatherND* op, const std::vector<TShape>& input_shapes, st
|
||||
// If batch_dims > 1, batch dimensions are need to be fused
|
||||
auto batch_dims = op->get_batch_dims();
|
||||
if (batch_dims > 1 && output_shapes[0].rank().is_static()) {
|
||||
const auto& output_base_shape = output_shapes[0];
|
||||
std::vector<DimType> output_shape{1};
|
||||
for (size_t dim = 0; dim < batch_dims; ++dim) {
|
||||
output_shape[0] *= output_base_shape[dim];
|
||||
}
|
||||
output_shape.insert(output_shape.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end());
|
||||
output_shapes[0] = TShape(output_shape);
|
||||
auto& output_base_shape = output_shapes[0];
|
||||
auto output_dims = std::vector<DimType>{output_base_shape[0]};
|
||||
std::for_each(output_base_shape.begin() + 1,
|
||||
output_base_shape.begin() + batch_dims,
|
||||
[&output_dims](const DimType& dim) {
|
||||
output_dims[0] *= dim;
|
||||
});
|
||||
output_dims.insert(output_dims.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end());
|
||||
output_shapes[0] = TShape(std::move(output_dims));
|
||||
}
|
||||
}
|
||||
} // namespace v5
|
||||
|
Loading…
Reference in New Issue
Block a user