fix strided slice neg out of bounds ends (#1177)
This commit is contained in:
parent
e05e8893f2
commit
08d8d36667
@ -18,7 +18,13 @@ bool ngraph::pass::UselessStridedSliceEraser::run_on_function(std::shared_ptr<ng
|
|||||||
continue;
|
continue;
|
||||||
if (ss->input(0).get_shape() != ss->output(0).get_shape())
|
if (ss->input(0).get_shape() != ss->output(0).get_shape())
|
||||||
continue;
|
continue;
|
||||||
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
|
|
||||||
|
auto stridesNode = std::dynamic_pointer_cast<ngraph::opset3::Constant>(ss->input_value(3).get_node_shared_ptr());
|
||||||
|
if (stridesNode) {
|
||||||
|
auto strides = stridesNode->cast_vector<int64_t>();
|
||||||
|
if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) { return strd < 0;}))
|
||||||
|
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return rewritten;
|
return rewritten;
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
#include <ngraph/function.hpp>
|
#include <ngraph/function.hpp>
|
||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
#include <ngraph/pass/constant_folding.hpp>
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
#include <ngraph_ops/fully_connected.hpp>
|
#include <ngraph_ops/fully_connected.hpp>
|
||||||
#include <transformations/convert_opset1_to_legacy/fc_bias_fusion.hpp>
|
#include <transformations/convert_opset1_to_legacy/fc_bias_fusion.hpp>
|
||||||
@ -97,6 +98,43 @@ TEST(TransformationTests, OptimizeSS_UselessDeletion) {
|
|||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, OptimizeSS_SkipUselessDeletionRevertCase) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
|
||||||
|
auto begin = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
|
||||||
|
auto end = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-6, -7, -8, -9});
|
||||||
|
auto stride = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1});
|
||||||
|
|
||||||
|
std::vector<int64_t> begin_mask = {1, 1, 1, 1};
|
||||||
|
std::vector<int64_t> end_mask = {0, 0, 0, 0};
|
||||||
|
|
||||||
|
auto ss = std::make_shared<ngraph::opset3::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
|
||||||
|
auto relu = std::make_shared<ngraph::opset3::Relu>(ss);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{data});
|
||||||
|
ngraph::pass::StridedSliceOptimization().run_on_function(f);
|
||||||
|
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{5, 5, 5, 5});
|
||||||
|
auto begin = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {0, 0, 0, 0});
|
||||||
|
auto end = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-6, -7, -8, -9});
|
||||||
|
auto stride = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {-1});
|
||||||
|
|
||||||
|
std::vector<int64_t> begin_mask = {1, 1, 1, 1};
|
||||||
|
std::vector<int64_t> end_mask = {0, 0, 0, 0};
|
||||||
|
|
||||||
|
auto ss = std::make_shared<ngraph::opset3::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
|
||||||
|
auto relu = std::make_shared<ngraph::opset3::Relu>(ss);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{data});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, OptimizeSS_Usefull_Test) {
|
TEST(TransformationTests, OptimizeSS_Usefull_Test) {
|
||||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
{
|
{
|
||||||
|
@ -0,0 +1,93 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "single_layer_tests/strided_slice.hpp"
|
||||||
|
#include "common_test_utils/test_constants.hpp"
|
||||||
|
|
||||||
|
using namespace LayerTestsDefinitions;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::vector<StridedSliceParams> ss_only_test_cases = {
|
||||||
|
StridedSliceParams{ { { 128, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 1, 1, 1 },
|
||||||
|
{ 0, 1, 1 }, { 0, 1, 1 }, { 1, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 128, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 1, 1, 1},
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 1, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, -1, 0 }, { 0, 0, 0 }, { 1, 1, 1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 1, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 9, 0 }, { 0, 11, 0 }, { 1, 1, 1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 1, 0 }, { 0, -1, 0 }, { 1, 1, 1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 9, 0 }, { 0, 7, 0 }, { -1, -1, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 7, 0 }, { 0, 9, 0 }, { -1, 1, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 4, 0 }, { 0, 9, 0 }, { -1, 2, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 4, 0 }, { 0, 10, 0 }, { -1, 2, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 9, 0 }, { 0, 4, 0 }, { -1, -2, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 10, 0 }, { 0, 4, 0 }, { -1, -2, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, 11, 0 }, { 0, 0, 0 }, { -1, -2, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100 }, { 0, -6, 0 }, { 0, -8, 0 }, { -1, -2, -1 },
|
||||||
|
{ 1, 0, 1 }, { 1, 0, 1 }, { 0, 0, 0 }, { 0, 0, 0 }, { 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 12, 100, 1, 1 }, { 0, -1, 0, 0 }, { 0, 0, 0, 0 }, { 1, 1, 1, 1 },
|
||||||
|
{ 1, 0, 1, 0 }, { 1, 0, 1, 0 }, { }, { 0, 1, 0, 1 }, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 2, 2, 2 }, { 0, 0, 0, 0 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 },
|
||||||
|
{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 },
|
||||||
|
{0, 0, 0, 0}, {1, 1, 1, 1}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 2, 2, 2 }, { 1, 1, 1, 1 }, { 2, 2, 2, 2 }, { 1, 1, 1, 1 },
|
||||||
|
{0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 2, 4, 3 }, { 0, 0, 0, 0 }, { 2, 2, 4, 3 }, { 1, 1, 2, 1 },
|
||||||
|
{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 2, 4, 2 }, { 1, 0, 0, 1 }, { 2, 2, 4, 2 }, { 1, 1, 2, 1 },
|
||||||
|
{0, 1, 1, 0}, {1, 1, 0, 0}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 1, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
|
||||||
|
{1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 2, 4, 2 }, { 1, 0, 0, 0 }, { 1, 2, 4, 2 }, { 1, 1, -2, -1 },
|
||||||
|
{0, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||||
|
{1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 1, 0, 0, 0}, {} },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
||||||
|
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}},
|
||||||
|
StridedSliceParams{ { { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
||||||
|
{ 0, 0, 0, 0 }, { 0, 1, 0, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_CPU, {}}
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
smoke_MKLDNN, StridedSliceLayerTest, ::testing::ValuesIn(ss_only_test_cases),
|
||||||
|
StridedSliceLayerTest::getTestCaseName);
|
||||||
|
|
||||||
|
} // namespace
|
@ -84,6 +84,12 @@ std::vector<StridedSliceParams> ss_only_test_cases = {
|
|||||||
StridedSliceParams{ { { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
StridedSliceParams{ { { 2, 3, 4, 5, 6 }, { 0, 1, 0, 0, 0 }, { 2, 3, 4, 5, 6 }, { 1, 1, 1, 1, 1 },
|
||||||
{1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 1, 0, 0, 0}, {} },
|
{1, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {}, {0, 1, 0, 0, 0}, {} },
|
||||||
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}},
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}},
|
||||||
|
StridedSliceParams{ { { 10, 12 }, { -1, 1 }, { -9999, 0 }, { -1, 1 },
|
||||||
|
{ 0, 1 }, { 0, 1 }, { 0, 0 }, { 0, 0 }, { 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}},
|
||||||
|
StridedSliceParams{ { { 5, 5, 5, 5 }, { -1, 0, -1, 0 }, { -50, 0, -60, 0 }, { -1, 1, -1, 1 },
|
||||||
|
{ 0, 0, 0, 0 }, { 0, 1, 0, 1 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 }, { 0, 0, 0, 0 } },
|
||||||
|
InferenceEngine::Precision::FP32, CommonTestUtils::DEVICE_GPU, {}}
|
||||||
};
|
};
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
@ -767,6 +767,14 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
|
|||||||
int64_t lb = begin[axis];
|
int64_t lb = begin[axis];
|
||||||
int64_t ub = end[axis];
|
int64_t ub = end[axis];
|
||||||
|
|
||||||
|
// set default value for stride or use given value
|
||||||
|
int64_t stride = 1;
|
||||||
|
if (strides.size() > axis)
|
||||||
|
{
|
||||||
|
stride = strides[axis];
|
||||||
|
}
|
||||||
|
NODE_VALIDATION_CHECK(node, stride != 0, "Stride must be non-zero");
|
||||||
|
|
||||||
// convert negative indexes to positive
|
// convert negative indexes to positive
|
||||||
// take max for this case: if abs(lb) > input_shape[input_shape_idx],then after
|
// take max for this case: if abs(lb) > input_shape[input_shape_idx],then after
|
||||||
// conversion lb < 0
|
// conversion lb < 0
|
||||||
@ -778,22 +786,14 @@ PartialShape ngraph::infer_slice_shape(const Node* node,
|
|||||||
|
|
||||||
if (ub < 0)
|
if (ub < 0)
|
||||||
{
|
{
|
||||||
ub = std::max(input_shape[input_shape_idx].get_length() + ub, int64_t(0));
|
ub = std::max(input_shape[input_shape_idx].get_length() + ub,
|
||||||
|
stride > 0 ? int64_t(0) : int64_t(-1));
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply restrictions when begin or end values more than max possible values.
|
// apply restrictions when begin or end values more than max possible values.
|
||||||
lb = std::min(input_shape[input_shape_idx].get_length(), lb);
|
lb = std::min(input_shape[input_shape_idx].get_length(), lb);
|
||||||
ub = std::min(input_shape[input_shape_idx].get_length(), ub);
|
ub = std::min(input_shape[input_shape_idx].get_length(), ub);
|
||||||
|
|
||||||
// set default value for stride or use given value
|
|
||||||
int64_t stride = 1;
|
|
||||||
if (strides.size() > axis)
|
|
||||||
{
|
|
||||||
stride = strides[axis];
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(node, stride != 0, "Stride must be non-zero");
|
|
||||||
|
|
||||||
int64_t dimension = 0;
|
int64_t dimension = 0;
|
||||||
if (stride < 0)
|
if (stride < 0)
|
||||||
{
|
{
|
||||||
|
@ -188,3 +188,20 @@ TEST(type_prop, strided_slice_default_stride_dynamic_shape_input)
|
|||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
FAIL() << "Deduced type check failed for unexpected reason";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, strided_slice_reverse_out_of_bounds)
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<op::Parameter>(ngraph::element::f32, ngraph::Shape{3, 4, 5});
|
||||||
|
auto begin = op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {100});
|
||||||
|
auto end = op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {-100});
|
||||||
|
auto stride = op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {-1});
|
||||||
|
|
||||||
|
std::vector<int64_t> begin_mask = {0, 0, 0, 0};
|
||||||
|
std::vector<int64_t> end_mask = {0, 0, 0, 0};
|
||||||
|
|
||||||
|
auto ss =
|
||||||
|
std::make_shared<op::v1::StridedSlice>(data, begin, end, stride, begin_mask, end_mask);
|
||||||
|
|
||||||
|
Shape expected{3, 4, 5};
|
||||||
|
EXPECT_EQ(ss->get_output_shape(0), expected);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user