[IE][VPU][DTS]: shrink mask support for StridedSlice and test (#3835)
This commit is contained in:
parent
2bfc941cf1
commit
ab66eab652
@ -33,6 +33,7 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
|
||||
const std::vector<int64_t> & strides,
|
||||
const ngraph::AxisSet & begin_mask,
|
||||
const ngraph::AxisSet & end_mask,
|
||||
const ngraph::AxisSet & shrink_mask,
|
||||
const ngraph::Output<ngraph::Node> & input_shape) {
|
||||
const auto& shape_type = input_shape.get_element_type();
|
||||
|
||||
@ -44,6 +45,9 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
|
||||
|
||||
ngraph::OutputVector output_dimensions;
|
||||
for (size_t axis = 0; axis < begin.size(); ++axis) {
|
||||
if (shrink_mask.count(axis)) {
|
||||
continue;
|
||||
}
|
||||
auto lb = begin[axis], ub = end[axis], stride = strides[axis];
|
||||
|
||||
ngraph::Output<ngraph::Node> lower_bound = ngraph::opset3::Constant::create(shape_type, {1}, {lb});
|
||||
@ -108,14 +112,11 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
|
||||
output_dimensions.push_back(output_dimension);
|
||||
}
|
||||
|
||||
if (output_dimensions.size() < inputShapeRank) {
|
||||
output_dimensions.push_back(gatherShapeElements(input_shape, output_dimensions.size(), inputShapeRank - output_dimensions.size()));
|
||||
size_t processed_dims_count = output_dimensions.size() + shrink_mask.size();
|
||||
if (processed_dims_count < inputShapeRank) {
|
||||
output_dimensions.push_back(gatherShapeElements(input_shape, processed_dims_count, inputShapeRank - processed_dims_count));
|
||||
}
|
||||
|
||||
VPU_THROW_UNLESS(output_dimensions.size() == inputShapeRank,
|
||||
"output shape rank {} must be equal to input shape rank {} for DTS of StridedSlice",
|
||||
output_dimensions.size(), inputShapeRank);
|
||||
|
||||
const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dimensions, 0);
|
||||
return output_shape;
|
||||
}
|
||||
@ -132,8 +133,6 @@ void dynamicToStaticShapeStridedSlice(std::shared_ptr<ngraph::Node> target) {
|
||||
const auto all_zero = [](const std::vector<int64_t> & v) {return std::all_of(v.cbegin(), v.cend(), [](const int64_t & i){return i == 0;});};
|
||||
VPU_THROW_UNLESS(all_zero(stridedSlice->get_new_axis_mask()),
|
||||
"dynamicToStaticShapeStridedSlice transformation is not applicable for {}, new_axis_mask expected to be zeros", target);
|
||||
VPU_THROW_UNLESS(all_zero(stridedSlice->get_shrink_axis_mask()),
|
||||
"dynamicToStaticShapeStridedSlice transformation is not applicable for {}, shrink_axis_mask expected to be zeros", target);
|
||||
VPU_THROW_UNLESS(all_zero(stridedSlice->get_ellipsis_mask()),
|
||||
"dynamicToStaticShapeStridedSlice transformation is not applicable for {}, ellipsis_mask expected to be zeros", target);
|
||||
|
||||
@ -152,6 +151,7 @@ void dynamicToStaticShapeStridedSlice(std::shared_ptr<ngraph::Node> target) {
|
||||
get_i64_vector_from_const(stridedSlice->input_value(3).get_node_shared_ptr()),
|
||||
convert_mask_to_axis_set(stridedSlice->get_begin_mask()),
|
||||
convert_mask_to_axis_set(stridedSlice->get_end_mask()),
|
||||
convert_mask_to_axis_set(stridedSlice->get_shrink_axis_mask()),
|
||||
input_shape);
|
||||
|
||||
const auto copied = stridedSlice->clone_with_new_inputs(target->input_values());
|
||||
|
@ -51,6 +51,8 @@ TEST_P(DSR_StridedSlice, CompareWithReference) {
|
||||
}
|
||||
|
||||
std::vector<StridedSliceParams> testCases = {
|
||||
{ { { 2, 3, 4, 5, 6 }, { 2, 3, 4, 5, 3 } },
|
||||
{ 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, -1 }, { 1, 1, 1, 1, 1 }, {1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 0, 0, 0, 1}, {} },
|
||||
{ { { 800, 4 }, { 1000, 4 } }, { 0, 0 }, { -1, 0 }, { 2, 1 }, { 1, 0 }, { 0, 1 }, {}, {}, {} },
|
||||
{ { { 1, 12, 80 }, { 1, 12, 100 } }, { 0, 9, 0 }, { 0, 11, 0 }, { 1, 1, 1 }, { 1, 0, 1 }, { 1, 0, 1 }, {}, {}, {} },
|
||||
{ { { 1, 7, 80 }, { 1, 12, 100 } }, { 0, 1, 0 }, { 0, -1, 0 }, { 1, 1, 1 }, { 1, 0, 1 }, { 1, 0, 1 }, {}, {}, {} },
|
||||
|
Loading…
Reference in New Issue
Block a user