fix strided slice neg out of bounds ends (#1177)
This commit is contained in:
parent
e05e8893f2
commit
08d8d36667
@ -18,7 +18,13 @@ bool ngraph::pass::UselessStridedSliceEraser::run_on_function(std::shared_ptr<ng
|
||||
continue;
|
||||
if (ss->input(0).get_shape() != ss->output(0).get_shape())
|
||||
continue;
|
||||
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
|
||||
|
||||
auto stridesNode = std::dynamic_pointer_cast<ngraph::opset3::Constant>(ss->input_value(3).get_node_shared_ptr());
|
||||
if (stridesNode) {
|
||||
auto strides = stridesNode->cast_vector<int64_t>();
|
||||
if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) { return strd < 0;}))
|
||||
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
|
||||
}
|
||||
}
|
||||
return rewritten;
|
||||
}
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph_ops/fully_connected.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/fc_bias_fusion.hpp>
|
||||
@ -97,6 +98,43 @@ TEST(TransformationTests, OptimizeSS_UselessDeletion) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, OptimizeSS_SkipUselessDeletionRevertCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
|
||||
auto begin = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
|
||||
auto end = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-6, -7, -8, -9});
|
||||
auto stride = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1});
|
||||
|
||||
std::vector<int64_t> begin_mask = {1, 1, 1, 1};
|
||||
std::vector<int64_t> end_mask = {0, 0, 0, 0};
|
||||
|
||||
auto ss = std::make_shared<ngraph::opset3::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
|
||||
auto relu = std::make_shared<ngraph::opset3::Relu>(ss);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{data});
|
||||
ngraph::pass::StridedSliceOptimization().run_on_function(f);
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
|
||||
auto begin = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
|
||||
auto end = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-6, -7, -8, -9});
|
||||
auto stride = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1});
|
||||
|
||||
std::vector<int64_t> begin_mask = {1, 1, 1, 1};
|
||||
std::vector<int64_t> end_mask = {0, 0, 0, 0};
|
||||
|
||||
auto ss = std::make_shared<ngraph::opset3::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
|
||||
auto relu = std::make_shared<ngraph::opset3::Relu>(ss);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, OptimizeSS_Usefull_Test) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
|
@ -0,0 +1,93 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "single_layer_tests/strided_slice.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<StridedSliceParams> ss_only_test_cases = {
|
||||
StridedSliceParams{ { { 128, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 1, 1, 1 },
|
||||
{ 0, 1, 1 }, { 0, 1, 1 }, { 1, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 128, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 1, 1, 1},
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 1, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, -1, 0 }, { 0, 0, 0 }, { 1, 1, 1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 1, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 9, 0 }, { 0, 11, 0 }, { 1, 1, 1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 1, 0 }, { 0, -1, 0 }, { 1, 1, 1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 9, 0 }, { 0, 7, 0 }, { -1, -1, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 7, 0 }, { 0, 9, 0 }, { -1, 1, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 4, 0 }, { 0, 9, 0 }, { -1, 2, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 4, 0 }, { 0, 10, 0 }, { -1, 2, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 9, 0 }, { 0, 4, 0 }, { -1, -2, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 10, 0 }, { 0, 4, 0 }, { -1, -2, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, 11, 0 }, { 0, 0, 0 }, { -1, -2, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100 }, { 0, -6, 0 }, { 0, -8, 0 }, { -1, -2, -1 },
|
||||
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 12, 100, 1, 1 }, { 0, -1, 0, 0 }, { 0, 0, 0, 0 }, { 1, 1, 1, 1 },
|
||||
{ 1, 0, 1, 0 }, { 1, 0, 1, 0 }, { }, { 0, 1, 0, 1 }, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 2, 2, 2 }, { 0, 0, 0, 0 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 },
|
||||
{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 },
|
||||
{0, 0, 0, 0}, {1, 1, 1, 1}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 },
|
||||
{0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 2, 4, 3 }, { 0, 0, 0, 0 }, { 2, 2, 4, 3 }, { 1, 1, 2, 1 },
|
||||
{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 2, 4, 2 }, { 1, 0, 0, 1 }, { 2, 2, 4, 2 }, { 1, 1, 2, 1 },
|
||||
{0, 1, 1, 0}, {1, 1, 0, 0}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 1, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
|
||||
{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
|
||||
{0, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||
{1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 1, 0, 0, 0}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
||||
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||
StridedSliceParams{ { { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
||||
{ 0, 0, 0, 0 }, { 0, 1, 0, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_MKLDNN, StridedSliceLayerTest, ::testing::ValuesIn(ss_only_test_cases),
|
||||
StridedSliceLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -84,6 +84,12 @@ std::vector<StridedSliceParams> ss_only_test_cases = {
|
||||
StridedSliceParams{ { { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||
{1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 1, 0, 0, 0}, {} },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}},
|
||||
StridedSliceParams{ { { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
||||
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}},
|
||||
StridedSliceParams{ { { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
||||
{ 0, 0, 0, 0 }, { 0, 1, 0, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 } },
|
||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
|
@ -767,6 +767,14 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
|
||||
int64_t lb = begin[axis];
|
||||
int64_t ub = end[axis];
|
||||
|
||||
// set default value for stride or use given value
|
||||
int64_t stride = 1;
|
||||
if (strides.size() > axis)
|
||||
{
|
||||
stride = strides[axis];
|
||||
}
|
||||
NODE_VALIDATION_CHECK(node, stride != 0, "Stride must be non-zero");
|
||||
|
||||
// convert negative indexes to positive
|
||||
// take max for this case: if abs(lb) > input_shape[input_shape_idx],then after
|
||||
// conversion lb < 0
|
||||
@ -778,22 +786,14 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
|
||||
|
||||
if (ub < 0)
|
||||
{
|
||||
ub = std::max(input_shape[input_shape_idx].get_length() + ub, int64_t(0));
|
||||
ub = std::max(input_shape[input_shape_idx].get_length() + ub,
|
||||
stride > 0 ? int64_t(0) : int64_t(-1));
|
||||
}
|
||||
|
||||
// apply restrictions when begin or end values more than max possible values.
|
||||
lb = std::min(input_shape[input_shape_idx].get_length(), lb);
|
||||
ub = std::min(input_shape[input_shape_idx].get_length(), ub);
|
||||
|
||||
// set default value for stride or use given value
|
||||
int64_t stride = 1;
|
||||
if (strides.size() > axis)
|
||||
{
|
||||
stride = strides[axis];
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(node, stride != 0, "Stride must be non-zero");
|
||||
|
||||
int64_t dimension = 0;
|
||||
if (stride < 0)
|
||||
{
|
||||
|
@ -188,3 +188,20 @@ TEST(type_prop, strided_slice_default_stride_dynamic_shape_input)
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, strided_slice_reverse_out_of_bounds)
|
||||
{
|
||||
auto data = std::make_shared<op::Parameter>(ngraph::element::f32, ngraph::Shape{3, 4, 5});
|
||||
auto begin = op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {100});
|
||||
auto end = op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {-100});
|
||||
auto stride = op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {-1});
|
||||
|
||||
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
|
||||
std::vector<int64_t> end_mask = {0, 0, 0, 0};
|
||||
|
||||
auto ss =
|
||||
std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
|
||||
|
||||
Shape expected{3, 4, 5};
|
||||
EXPECT_EQ(ss->get_output_shape(0), expected);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user