diff --git a/src/core/shape_inference/include/gather_nd_shape_inference.hpp b/src/core/shape_inference/include/gather_nd_shape_inference.hpp index 3b79e602679..210e6706628 100644 --- a/src/core/shape_inference/include/gather_nd_shape_inference.hpp +++ b/src/core/shape_inference/include/gather_nd_shape_inference.hpp @@ -37,29 +37,31 @@ std::vector gather_nd_base_shape_infer(const TOp* op, const std::vector< NODE_VALIDATION_CHECK( op, - static_cast(indices_tuple_length + op->get_batch_dims()) <= data_pshape.rank().get_length(), + cmp::le(indices_tuple_length + op->get_batch_dims(), data_pshape.rank().get_length()), "Length of a tuple with indices must not exceed a rank of data tensor excluding batch dimensions."); - int64_t slice_length = data_pshape.rank().get_length() - indices_tuple_length - batch_dims; - int64_t output_indices_length = indices_pshape.rank().get_length() - batch_dims - 1; - auto output_rank = output_indices_length + slice_length; + const auto slice_length = data_pshape.size() - indices_tuple_length - batch_dims; + const auto output_indices_length = indices_pshape.size() - batch_dims - 1; using DimType = typename TShape::value_type; - std::vector output_shape(output_rank + batch_dims); - for (size_t dim = 0; dim < batch_dims; ++dim) { + std::vector output_dims(batch_dims); + output_dims.reserve(batch_dims + output_indices_length + slice_length); + // Merge batch dimensions + for (size_t dim_idx = 0; dim_idx < batch_dims; ++dim_idx) { NODE_VALIDATION_CHECK(op, - DimType::merge(output_shape[dim], data_pshape[dim], indices_pshape[dim]), + DimType::merge(output_dims[dim_idx], data_pshape[dim_idx], indices_pshape[dim_idx]), "Batch dimensions of data and indices must be the same."); } - for (int64_t dim = 0; dim < output_indices_length; ++dim) { - output_shape[batch_dims + dim] = indices_pshape[batch_dims + dim]; + // Insert middle dimensions from the indices shape + for (auto dim_idx = batch_dims; dim_idx < indices_pshape.size() - 1; ++dim_idx) { + output_dims.emplace_back(indices_pshape[dim_idx]); } - for (int64_t dim = 0; dim < slice_length; ++dim) { - output_shape[batch_dims + output_indices_length + dim] = - data_pshape[batch_dims + indices_tuple_length + dim]; + // Insert dimensions fully taken from the data shape + for (auto dim_idx = batch_dims + indices_tuple_length; dim_idx < data_pshape.size(); ++dim_idx) { + output_dims.emplace_back(data_pshape[dim_idx]); } - return std::vector{TShape(output_shape)}; + return {TShape(std::move(output_dims))}; } else { - return std::vector{ov::PartialShape::dynamic()}; + return {ov::PartialShape::dynamic()}; } } } // namespace gather_nd @@ -72,13 +74,15 @@ void shape_infer(const GatherND* op, const std::vector& input_shapes, st // If batch_dims > 1, batch dimensions are need to be fused auto batch_dims = op->get_batch_dims(); if (batch_dims > 1 && output_shapes[0].rank().is_static()) { - const auto& output_base_shape = output_shapes[0]; - std::vector output_shape{1}; - for (size_t dim = 0; dim < batch_dims; ++dim) { - output_shape[0] *= output_base_shape[dim]; - } - output_shape.insert(output_shape.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end()); - output_shapes[0] = TShape(output_shape); + auto& output_base_shape = output_shapes[0]; + auto output_dims = std::vector{output_base_shape[0]}; + std::for_each(output_base_shape.begin() + 1, + output_base_shape.begin() + batch_dims, + [&output_dims](const DimType& dim) { + output_dims[0] *= dim; + }); + output_dims.insert(output_dims.begin() + 1, output_base_shape.begin() + batch_dims, output_base_shape.end()); + output_shapes[0] = TShape(std::move(output_dims)); } } } // namespace v5