[IE][VPU][DTS]: shrink mask support for StridedSlice and test (#3835)

This commit is contained in:
Nikita Kudriavtsev 2021-01-19 11:46:59 +03:00 committed by GitHub
parent 2bfc941cf1
commit ab66eab652
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 8 deletions

View File

@ -33,6 +33,7 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
const std::vector<int64_t> & strides,
const ngraph::AxisSet & begin_mask,
const ngraph::AxisSet & end_mask,
const ngraph::AxisSet & shrink_mask,
const ngraph::Output<ngraph::Node> & input_shape) {
const auto& shape_type = input_shape.get_element_type();
@ -44,6 +45,9 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
ngraph::OutputVector output_dimensions;
for (size_t axis = 0; axis < begin.size(); ++axis) {
if (shrink_mask.count(axis)) {
continue;
}
auto lb = begin[axis], ub = end[axis], stride = strides[axis];
ngraph::Output<ngraph::Node> lower_bound = ngraph::opset3::Constant::create(shape_type, {1}, {lb});
@ -108,14 +112,11 @@ std::shared_ptr<ngraph::Node> calculate_output_shape(
output_dimensions.push_back(output_dimension);
}
if (output_dimensions.size() < inputShapeRank) {
output_dimensions.push_back(gatherShapeElements(input_shape, output_dimensions.size(), inputShapeRank - output_dimensions.size()));
size_t processed_dims_count = output_dimensions.size() + shrink_mask.size();
if (processed_dims_count < inputShapeRank) {
output_dimensions.push_back(gatherShapeElements(input_shape, processed_dims_count, inputShapeRank - processed_dims_count));
}
VPU_THROW_UNLESS(output_dimensions.size() == inputShapeRank,
"output shape rank {} must be equal to input shape rank {} for DTS of StridedSlice",
output_dimensions.size(), inputShapeRank);
const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dimensions, 0);
return output_shape;
}
@ -132,8 +133,6 @@ void dynamicToStaticShapeStridedSlice(std::shared_ptr<ngraph::Node> target) {
const auto all_zero = [](const std::vector<int64_t> & v) {return std::all_of(v.cbegin(), v.cend(), [](const int64_t & i){return i == 0;});};
VPU_THROW_UNLESS(all_zero(stridedSlice->get_new_axis_mask()),
"dynamicToStaticShapeStridedSlice transformation is not applicable for {}, new_axis_mask expected to be zeros", target);
VPU_THROW_UNLESS(all_zero(stridedSlice->get_shrink_axis_mask()),
"dynamicToStaticShapeStridedSlice transformation is not applicable for {}, shrink_axis_mask expected to be zeros", target);
VPU_THROW_UNLESS(all_zero(stridedSlice->get_ellipsis_mask()),
"dynamicToStaticShapeStridedSlice transformation is not applicable for {}, ellipsis_mask expected to be zeros", target);
@ -152,6 +151,7 @@ void dynamicToStaticShapeStridedSlice(std::shared_ptr<ngraph::Node> target) {
get_i64_vector_from_const(stridedSlice->input_value(3).get_node_shared_ptr()),
convert_mask_to_axis_set(stridedSlice->get_begin_mask()),
convert_mask_to_axis_set(stridedSlice->get_end_mask()),
convert_mask_to_axis_set(stridedSlice->get_shrink_axis_mask()),
input_shape);
const auto copied = stridedSlice->clone_with_new_inputs(target->input_values());

View File

@ -51,6 +51,8 @@ TEST_P(DSR_StridedSlice, CompareWithReference) {
}
std::vector<StridedSliceParams> testCases = {
{ { { 2, 3, 4, 5, 6 }, { 2, 3, 4, 5, 3 } },
{ 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, -1 }, { 1, 1, 1, 1, 1 }, {1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 0, 0, 0, 1}, {} },
{ { { 800, 4 }, { 1000, 4 } }, { 0, 0 }, { -1, 0 }, { 2, 1 }, { 1, 0 }, { 0, 1 }, {}, {}, {} },
{ { { 1, 12, 80 }, { 1, 12, 100 } }, { 0, 9, 0 }, { 0, 11, 0 }, { 1, 1, 1 }, { 1, 0, 1 }, { 1, 0, 1 }, {}, {}, {} },
{ { { 1, 7, 80 }, { 1, 12, 100 } }, { 0, 1, 0 }, { 0, -1, 0 }, { 1, 1, 1 }, { 1, 0, 1 }, { 1, 0, 1 }, {}, {}, {} },