[IE][VPU][nGraph]: Fixes StridedSlice DTS (#1328)
* In case of Begin/End/Stride inputs of StridedSlice have rank less than input data rank - remaining dimensions must be kept unchanged. * Previous, implementation had UB in such cases - out of bound vector element access Signed-off-by: Gladilov, Gleb <gleb.gladilov@intel.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
|
||||
namespace vpu {
|
||||
|
||||
@@ -32,10 +33,16 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
|
||||
const ngraph::AxisSet & begin_mask,
|
||||
const ngraph::AxisSet & end_mask,
|
||||
const ngraph::Output<ngraph::Node> & input_shape) {
|
||||
const auto shape_type = input_shape.get_element_type();
|
||||
const auto& shape_type = input_shape.get_element_type();
|
||||
|
||||
VPU_THROW_UNLESS(begin.size() == end.size() && begin.size() == strides.size(),
|
||||
"Begin, end and strides inputs must be of the same size, but {}, {} and {} given accordingly", begin.size(), end.size(), strides.size());
|
||||
const auto inputShapeRank = input_shape.get_partial_shape()[0].get_length();
|
||||
VPU_THROW_UNLESS(inputShapeRank >= begin.size(),
|
||||
"Input shape rank must not be less than begin/end/strides size, but {} and {} given accordingly", inputShapeRank, begin.size());
|
||||
|
||||
ngraph::OutputVector output_dimensions;
|
||||
for (int64_t axis = 0; axis < input_shape.get_partial_shape()[0].get_length(); ++axis) {
|
||||
for (int64_t axis = 0; axis < begin.size(); ++axis) {
|
||||
auto lb = begin[axis], ub = end[axis], stride = strides[axis];
|
||||
|
||||
ngraph::Output<ngraph::Node> lower_bound = ngraph::opset3::Constant::create(shape_type, {1}, {lb});
|
||||
@@ -99,6 +106,22 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
|
||||
}
|
||||
output_dimensions.push_back(output_dimension);
|
||||
}
|
||||
|
||||
if (output_dimensions.size() < inputShapeRank) {
|
||||
std::vector<std::int64_t> indices(inputShapeRank - output_dimensions.size());
|
||||
std::iota(indices.begin(), indices.end(), static_cast<std::int64_t>(output_dimensions.size()));
|
||||
|
||||
const auto tail = std::make_shared<ngraph::opset3::Gather>(
|
||||
input_shape,
|
||||
ngraph::opset3::Constant::create(ngraph::element::i64, {indices.size()}, indices),
|
||||
ngraph::opset3::Constant::create(shape_type, {}, {0}));
|
||||
output_dimensions.push_back(tail);
|
||||
}
|
||||
|
||||
VPU_THROW_UNLESS(output_dimensions.size() == inputShapeRank,
|
||||
"output shape rank {} must be equal to input shape rank {} for DTS of StridedSlice",
|
||||
output_dimensions.size(), inputShapeRank);
|
||||
|
||||
const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dimensions, 0);
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user