[CPU] StridedSlice fix execution for empty output tensor (#16045)
* stridedslice skip execution with 0 dims * use isExecutable & add subgraph tests * remove useless code
This commit is contained in:
parent
cb7eeadd62
commit
0e9b133de5
@ -281,7 +281,7 @@ void StridedSlice::initSupportedPrimitiveDescriptors() {
|
||||
}
|
||||
|
||||
bool StridedSlice::isExecutable() const {
|
||||
return !isInputTensorAtPortEmpty(0);
|
||||
return !isInputTensorAtPortEmpty(0) && !isOutputTensorAtPortEmpty(0);
|
||||
}
|
||||
|
||||
void StridedSlice::createPrimitive() {
|
||||
@ -310,7 +310,6 @@ void StridedSlice::prepareParams() {
|
||||
dstMemory.push_back(getChildEdgeAt(i)->getMemoryPtr());
|
||||
}
|
||||
}
|
||||
|
||||
execPtr = std::make_shared<StridedSliceCommonExecutor>(attrs, srcMemory, dstMemory, errorPrefix);
|
||||
}
|
||||
|
||||
|
@ -213,10 +213,10 @@ const std::vector<Shape> inputShapesStatic4D = {
|
||||
|
||||
const std::vector<InputShape> inputShapesDynamic4D = {
|
||||
{{-1, -1, -1, -1},
|
||||
{{ 1, 5, 32, 32 }, { 2, 5, 32, 32 }, { 1, 5, 64, 64 }}},
|
||||
{{ 1, 5, 32, 32 }, { 2, 5, 32, 32 }, { 1, 5, 64, 64 }, {0, 0, 0, 0}}},
|
||||
|
||||
{{-1, 5, -1, -1},
|
||||
{{ 1, 5, 32, 32 }, { 2, 5, 32, 32 }, { 3, 5, 32, 36 }}},
|
||||
{{ 1, 5, 32, 32 }, { 2, 5, 32, 32 }, { 3, 5, 32, 36 }, {0, 5, 0, 0}}},
|
||||
|
||||
{{{1, 5}, 5, {32, 64}, {32, 64}},
|
||||
{{ 2, 5, 32, 32 }, { 1, 5, 48, 32 }, { 5, 5, 32, 32 }}},
|
||||
@ -352,10 +352,10 @@ const std::vector<Shape> inputShapesStatic5D = {
|
||||
|
||||
const std::vector<InputShape> inputShapesDynamic5D = {
|
||||
{{-1, -1, -1, -1, -1},
|
||||
{{ 1, 5, 32, 32, 32 }, { 1, 5, 32, 32, 48 }, { 1, 5, 64, 64, 64 }, { 1, 10, 32, 32, 32 }}},
|
||||
{{ 1, 5, 32, 32, 32 }, { 1, 5, 32, 32, 48 }, { 1, 5, 64, 64, 64 }, { 1, 10, 32, 32, 32 }, {0, 0, 0, 0, 0}}},
|
||||
|
||||
{{-1, 5, -1, -1, -1},
|
||||
{{ 1, 5, 32, 32, 48 }, { 1, 5, 32, 48, 32 }, { 1, 5, 48, 32, 32 }}},
|
||||
{{ 1, 5, 32, 32, 48 }, { 1, 5, 32, 48, 32 }, { 1, 5, 48, 32, 32 }, {0, 5, 0, 0, 0}}},
|
||||
|
||||
{{{1, 5}, 5, {32, 64}, {32, 64}, {32, 64}},
|
||||
{{ 2, 5, 32, 32, 32 }, { 1, 5, 48, 32, 32 }, { 5, 5, 32, 32, 48 }}},
|
||||
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace ov::test;
|
||||
using namespace ngraph;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
/*
|
||||
param1 [56] param2 [-1, -1, 768] (dynamic shape)
|
||||
\ |
|
||||
\ |
|
||||
\ shapeOf [4] (variable)
|
||||
\ |
|
||||
\ |
|
||||
\ Gather (get dynamic element) [1] (static value)
|
||||
\ |
|
||||
\ | OtherConstants
|
||||
\ | /
|
||||
StridedSlice [47] (Static output shape)
|
||||
|
|
||||
|
|
||||
Result
|
||||
*/
|
||||
|
||||
class StridedSliceZeroDimsTest : public SubgraphBaseTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
InputShape inpShape0 = {{}, {{56}}};
|
||||
InputShape inpShape1 = {{-1, -1, 768}, {{1, 544, 768}}};
|
||||
init_input_shapes({inpShape0, inpShape1});
|
||||
auto inputParams = builder::makeDynamicParams(element::f32, inputDynamicShapes);
|
||||
auto end = builder::makeConstant(element::i64, {1}, std::vector<int64_t>{2147483647});
|
||||
auto stride = builder::makeConstant(element::i64, {1}, std::vector<int64_t>{1});
|
||||
auto indices = builder::makeConstant(element::i64, {1}, std::vector<int64_t>{1});
|
||||
auto axes = builder::makeConstant(element::i64, {1}, std::vector<int64_t>{0});
|
||||
auto shapeOf = std::make_shared<opset9::ShapeOf>(inputParams[1]);
|
||||
auto gather = std::make_shared<opset9::Gather>(shapeOf, indices, axes);
|
||||
auto strided_slice = builder::makeStridedSlice(inputParams.front(), gather, end, stride, element::f32, {0}, {0});
|
||||
NodeVector results{strided_slice};
|
||||
function = std::make_shared<Function>(results, inputParams, "StridedSliceStaticShape");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(StridedSliceZeroDimsTest, smoke_CompareWithRefs) {
|
||||
run();
|
||||
}
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
Loading…
Reference in New Issue
Block a user