[CPU] Dynamic shapes. Gather. (#7858)

This commit is contained in:
Nikolay Shchegolev 2021-10-26 08:45:47 +03:00 committed by GitHub
parent 345c3510f3
commit 65d6010e4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 446 additions and 61 deletions

View File

@ -2,9 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <string>
#include <mkldnn_types.h>
#include <vector>
#include "ie_parallel.hpp"
#include "mkldnn_gather_node.h"
#include <ngraph/opsets/opset1.hpp>
@ -13,21 +13,17 @@
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
bool MKLDNNGatherNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
bool MKLDNNGatherNode::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (isDynamicNgraphNode(op)) {
errorMessage = "Doesn't support op with dynamic shapes";
return false;
}
const auto gatherOp = ngraph::as_type_ptr<const ngraph::op::v7::Gather>(op);
if (!gatherOp) {
errorMessage = "Only opset7 Gather operation is supported";
if (!one_of(op->get_type_info(),
ov::op::v7::Gather::get_type_info_static())) {
errorMessage = "Not supported Gather operation version. CPU plug-in supports only 7 version.";
return false;
}
const auto axesOp = gatherOp->get_input_node_shared_ptr(GATHER_AXIS);
if (!ngraph::as_type_ptr<const ngraph::op::Constant>(axesOp)) {
errorMessage = "Only Constant operation on 'axis' input is supported";
if (op->get_input_node_shared_ptr(GATHER_AXIS)->get_type_info() != ov::op::v0::Constant::get_type_info_static()) {
// TODO: Support parameterized Axis input for dynamic shapes.
errorMessage = "Only Constant operation on 'axis' input is supported.";
return false;
}
} catch (...) {
@ -37,41 +33,37 @@ bool MKLDNNGatherNode::isSupportedOperation(const std::shared_ptr<const ngraph::
return true;
}
MKLDNNGatherNode::MKLDNNGatherNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
MKLDNNGatherNode::MKLDNNGatherNode(const std::shared_ptr<ov::Node>& op, const mkldnn::engine& eng,
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
errorPrefix_ = std::string("Layer Gather with name '") + op->get_friendly_name() + "' ";
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage;
}
errorPrefix = std::string("Layer Gather with name '") + op->get_friendly_name() + "' ";
auto gatherOp = ngraph::as_type_ptr<ngraph::op::v7::Gather>(op);
if (gatherOp->get_input_size() != 3 || gatherOp->get_output_size() != 1)
IE_THROW() << errorPrefix_ << "has incorrect number of input/output edges!";
if (op->get_input_size() != 3 || op->get_output_size() != 1)
IE_THROW() << errorPrefix << "has incorrect number of input/output edges!";
const SizeVector& srcDims = gatherOp->get_input_shape(GATHER_DATA);
const SizeVector& idxDims = gatherOp->get_input_shape(GATHER_INDEXES);
if (srcDims.size() == 0)
IE_THROW() << errorPrefix_ << "has incorrect input parameters dimension!";
dataSrcRank = inputShapes[GATHER_DATA].getRank();
const auto idxRank = inputShapes[GATHER_INDEXES].getRank();
if (dataSrcRank == 0 || idxRank == 0)
IE_THROW() << errorPrefix << "has incorrect input parameters ranks.";
axis = static_cast<int>(gatherOp->get_axis());
if (axis < 0)
axis += srcDims.size();
if (!(0 <= axis && axis < static_cast<int>(srcDims.size())))
IE_THROW() << errorPrefix_ << "has incorrect input parameters dimensions and axis number!";
batchDims = static_cast<int>(gatherOp->get_batch_dims());
batchDims = static_cast<int>(ov::as_type_ptr<ov::op::v7::Gather>(op)->get_batch_dims());
if (batchDims < 0)
batchDims += idxDims.size();
if (!(0 <= batchDims && batchDims <= std::min(static_cast<int>(srcDims.size()), static_cast<int>(idxDims.size()))) ||
batchDims > axis)
IE_THROW() << errorPrefix_ << "has incorrect batch_dims " << batchDims << "!";
batchDims += idxRank;
if (batchDims < 0 || batchDims >= std::min(static_cast<int>(dataSrcRank), static_cast<int>(idxRank)))
IE_THROW() << errorPrefix << "has incorrect batch_dims " << batchDims << "!";
for (int i = 0; i < batchDims; i++) {
if (srcDims[i] != idxDims[i])
IE_THROW() << errorPrefix_ << "has incorrect first " << batchDims << " data and indices dimensions!";
if (op->get_input_node_shared_ptr(GATHER_AXIS)->get_type_info() == ov::op::v0::Constant::get_type_info_static()) {
isAxisInputConst = true;
axis = ov::as_type<ov::op::v0::Constant>(op->get_input_node_ptr(GATHER_AXIS))->cast_vector<int>()[0];
if (axis < 0)
axis += dataSrcRank;
if (axis < 0 || axis >= dataSrcRank || batchDims > axis)
IE_THROW() << errorPrefix << "has incorrect input parameter axis value: " << axis;
}
dataSize = getOriginalInputPrecisionAtPort(GATHER_DATA).size();
}
void MKLDNNGatherNode::initSupportedPrimitiveDescriptors() {
@ -81,25 +73,29 @@ void MKLDNNGatherNode::initSupportedPrimitiveDescriptors() {
Precision dataPrecision = getOriginalInputPrecisionAtPort(GATHER_DATA);
addSupportedPrimDesc({{LayoutType::ncsp, dataPrecision},
{LayoutType::ncsp, Precision::I32},
{LayoutType::ncsp, Precision::I32}},
{LayoutType::ncsp, Precision::I32, isAxisInputConst}},
{{LayoutType::ncsp, dataPrecision}},
impl_desc_type::ref_any);
}
void MKLDNNGatherNode::createPrimitive() {
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
IE_THROW() << errorPrefix_ << " has not allocated destination memory.";
void MKLDNNGatherNode::prepareParams() {
auto& srcMemPtr = getParentEdgeAt(GATHER_DATA)->getMemoryPtr();
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
IE_THROW() << errorPrefix_ << " has not allocated input memory.";
IE_THROW() << errorPrefix << " has not allocated input memory.";
if (getSelectedPrimitiveDescriptor() == nullptr)
IE_THROW() << errorPrefix_ << " has unidentified preferable primitive descriptor.";
IE_THROW() << errorPrefix << " has unidentified preferable primitive descriptor.";
const SizeVector srcDims = getParentEdgeAt(GATHER_DATA)->getMemory().getStaticDims();
const SizeVector idxDims = getParentEdgeAt(GATHER_INDEXES)->getMemory().getStaticDims();
const SizeVector dstDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
dataSize = getParentEdgeAt(GATHER_DATA)->getMemory().getDesc().getPrecision().size();
const auto& srcDims = srcMemPtr->getStaticDims();
const auto& idxDims = getParentEdgeAt(GATHER_INDEXES)->getMemory().getStaticDims();
const auto& dstDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
if (!isAxisInputConst) {
axis = (reinterpret_cast<const int32_t*>(getParentEdgeAt(GATHER_AXIS)->getMemoryPtr()->GetPtr()))[0];
if (axis < 0)
axis += dataSrcRank;
if (axis < 0 || axis >= dataSrcRank || batchDims > axis)
IE_THROW() << errorPrefix << "has incorrect input parameter axis value: " << axis;
}
indexRange = srcDims[axis];
batchSize = std::accumulate(srcDims.begin(), srcDims.begin() + batchDims, 1, std::multiplies<size_t>());
@ -109,9 +105,23 @@ void MKLDNNGatherNode::createPrimitive() {
idxBatchStride = std::accumulate(idxDims.begin() + batchDims, idxDims.end(), 1, std::multiplies<size_t>());
dstBatchStride = std::accumulate(dstDims.begin() + batchDims, dstDims.end(), 1, std::multiplies<size_t>());
len = dataLength * dataSize;
if (dataLength == 0)
IE_THROW() << errorPrefix_ << "had incorrect input parameters dimension!";
IE_THROW() << errorPrefix << "had incorrect input parameters dimension!";
}
bool MKLDNNGatherNode::needPrepareParams() const {
bool result = MKLDNNNode::needPrepareParams();
if (!isAxisInputConst)
result = result || axis != (reinterpret_cast<const int32_t*>(getParentEdgeAt(GATHER_AXIS)->getMemoryPtr()->GetPtr()))[0];
return result;
}
void MKLDNNGatherNode::createPrimitive() {
if (inputShapesDefined()) {
if (needPrepareParams())
prepareParams();
updateLastInputDims();
}
}
void MKLDNNGatherNode::execute(mkldnn::stream strm) {
@ -138,6 +148,10 @@ void MKLDNNGatherNode::execute(mkldnn::stream strm) {
});
}
void MKLDNNGatherNode::executeDynamicImpl(mkldnn::stream strm) {
execute(strm);
}
bool MKLDNNGatherNode::created() const {
return getType() == Gather;
}

View File

@ -4,10 +4,10 @@
#pragma once
#include <ie_common.h>
#include <mkldnn_node.h>
#include <string>
#include <memory>
#include <string>
#include <vector>
namespace MKLDNNPlugin {
@ -24,6 +24,11 @@ public:
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
protected:
void executeDynamicImpl(mkldnn::stream strm) override;
bool needPrepareParams() const override;
void prepareParams() override;
private:
int axis = 0;
int batchDims = 0;
@ -37,12 +42,14 @@ private:
size_t dstBatchStride = 1;
size_t dataSize = 1;
size_t len = 1;
int dataSrcRank = 1;
bool isAxisInputConst = false;
static const size_t GATHER_DATA = 0;
static const size_t GATHER_INDEXES = 1;
static const size_t GATHER_AXIS = 2;
static constexpr size_t GATHER_DATA = 0;
static constexpr size_t GATHER_INDEXES = 1;
static constexpr size_t GATHER_AXIS = 2;
std::string errorPrefix_;
std::string errorPrefix;
};
} // namespace MKLDNNPlugin

View File

@ -5,20 +5,24 @@
#include <vector>
#include "single_layer_tests/gather.hpp"
#include "common_test_utils/test_constants.hpp"
#include "ngraph_functions/builders.hpp"
using namespace LayerTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::I64,
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::BF16,
InferenceEngine::Precision::I8
};
// Just need to check types transformation.
const std::vector<InferenceEngine::Precision> netPrecisionsTrCheck = {
InferenceEngine::Precision::I64,
InferenceEngine::Precision::FP16
};
const std::vector<std::vector<size_t>> inputShapes_1D = {
std::vector<size_t>{4},
};
@ -46,12 +50,25 @@ const auto gather7Params_1D = testing::Combine(
INSTANTIATE_TEST_SUITE_P(smoke_Gather7_1D, Gather7LayerTest, gather7Params_1D, Gather7LayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_TypesTrf, Gather7LayerTest,
testing::Combine(
testing::ValuesIn(inputShapes_1D),
testing::ValuesIn(indicesShapes_1D),
testing::ValuesIn(axes_batchdims_1D),
testing::ValuesIn(netPrecisionsTrCheck),
testing::Values(InferenceEngine::Precision::UNSPECIFIED),
testing::Values(InferenceEngine::Precision::UNSPECIFIED),
testing::Values(InferenceEngine::Layout::ANY),
testing::Values(InferenceEngine::Layout::ANY),
testing::Values(CommonTestUtils::DEVICE_CPU)),
Gather7LayerTest::getTestCaseName);
const std::vector<std::vector<size_t>> inputShapes_2D = {
std::vector<size_t>{4, 19},
};
const std::vector<std::vector<size_t>> indicesShapes_2D = {
std::vector<size_t>{4},
std::vector<size_t>{4, 1},
std::vector<size_t>{4, 2},
};

View File

@ -69,6 +69,8 @@ std::vector<std::string> disabledTestPatterns() {
// TODO: 57562 No dynamic output shape support
R"(.*NonZeroLayerTest.*)",
// TODO: 69084 Not constant Axis input produces dynamic output shape.
R"(.*GatherLayerTestCPU.*constAx=False.*)",
// Not expected behavior
R"(.*Behavior.*InferRequestIOBBlobSetLayoutTest.*layout=(95|OIHW).*)",
R"(.*Behavior.*InferRequestIOBBlobSetLayoutTest.*layout=(95|OIHW).*)",

View File

@ -0,0 +1,345 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <shared_test_classes/single_layer/gather.hpp>
#include "ngraph_functions/builders.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace InferenceEngine;
using namespace CPUTestUtils;
namespace CPULayerTestsDefinitions {
using inputShapesPair = std::pair<std::vector<ov::PartialShape>, std::vector<std::vector<ov::Shape>>>;
typedef std::tuple<
inputShapesPair, // Input shapes
int64_t, // Axis
int64_t, // Batch dims
InferenceEngine::Precision, // Network precision
bool, // Is axis input constant
std::string, // Device name
CPUSpecificParams // CPU specific params
> GatherLayerTestCPUParams;
class GatherLayerTestCPU : public testing::WithParamInterface<GatherLayerTestCPUParams>,
virtual public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
public:
static std::string getTestCaseName(testing::TestParamInfo<GatherLayerTestCPUParams> obj) {
inputShapesPair inputShapes;
int axis, batchDims;
Precision netPrecision;
std::string targetDevice;
bool isAxisConstant;
CPUSpecificParams cpuParams;
std::tie(inputShapes, axis, batchDims, netPrecision, isAxisConstant, targetDevice, cpuParams) = obj.param;
std::ostringstream result;
result << "DynShapes=" << CommonTestUtils::partialShape2str(inputShapes.first) << "_";
result << "StatShapes=" << CommonTestUtils::vec2str(inputShapes.second) << "_";
result << "axis=" << axis << "_";
result << "batchDims=" << batchDims << "_";
result << "netPrc=" << netPrecision.name() << "_";
result << "constAx=" << (isAxisConstant ? "True" : "False") << "_";
result << "trgDev=" << targetDevice;
result << CPUTestsBase::getTestCaseName(cpuParams);
return result.str();
}
protected:
void SetUp() override {
inputShapesPair inputShapes;
int64_t batchDims;
Precision netPrecision;
CPUSpecificParams cpuParams;
bool isAxisConstant = true;
std::tie(inputShapes, axis, batchDims, netPrecision, isAxisConstant, targetDevice, cpuParams) = this->GetParam();
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
selectedType = std::string("ref_any_") + netPrecision.name();
targetStaticShapes.reserve(inputShapes.second.size());
inputDynamicShapes.reserve(inputShapes.first.size());
for (int i = 0; i < (isAxisConstant ? 2 : 3); i++) {
if (inputShapes.second.size() > i)
targetStaticShapes.push_back({inputShapes.second[i]});
if (inputShapes.first.size() > i)
inputDynamicShapes.push_back(inputShapes.first[i]);
}
const ov::Shape& inputDataShape = targetStaticShapes.front().front(), indicesShape = targetStaticShapes.front()[1];
dataSrcRank = inputDataShape.size();
const auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
ov::ParameterVector functionParams {
ngraph::builder::makeParams(ngPrc, { {"data", inputDataShape} })[0],
ngraph::builder::makeParams(ov::element::i32, { {"indices", indicesShape} })[0]
};
if (!isAxisConstant) {
functionParams.push_back(ngraph::builder::makeParams(ov::element::i32, { {"axis", {1}} })[0]);
}
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(functionParams));
std::shared_ptr<ov::Node> gatherNode;
if (isAxisConstant) {
gatherNode = std::make_shared<ov::op::v8::Gather>(paramOuts[0], paramOuts[1],
ov::op::v0::Constant::create(ov::element::i64, ov::Shape({}), { axis }), batchDims);
} else {
gatherNode = std::make_shared<ov::op::v8::Gather>(paramOuts[0], paramOuts[1], paramOuts[2], batchDims);
}
ov::ResultVector results{ std::make_shared<ov::op::v0::Result>(gatherNode) };
function = std::make_shared<ov::Function>(results, functionParams, "Gather");
}
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &inputInfo) const override {
if (inputInfo.name() == "indices") {
const auto& td = inputInfo.getTensorDesc();
size_t normAxis = axis < 0 ? axis + dataSrcRank : axis;
const auto axDim = targetStaticShapes[index][0][normAxis];
if (axDim == 1) {
// Random generator cannot generate values in range [0; 0]
int values[1] = { 0 };
return FuncTestUtils::createAndFillBlobWithFloatArray<int32_t>(td, values, 1);
} else {
return FuncTestUtils::createAndFillBlob(td, axDim - 1, 0);
}
} else if (inputInfo.name() == "axis") {
int values[1] = { static_cast<int32_t>(axis) };
return FuncTestUtils::createAndFillBlobWithFloatArray<int32_t>(inputInfo.getTensorDesc(), values, 1);
} else {
return LayerTestsCommon::GenerateInput(inputInfo);
}
}
int64_t axis = 0;
int64_t dataSrcRank = 0;
};
TEST_P(GatherLayerTestCPU, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
Run();
CheckPluginRelatedResults(executableNetwork, "Gather");
}
namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::BF16,
InferenceEngine::Precision::I8
};
// 1D
const std::vector<inputShapesPair> staticInputShapes1D = {
{
{},
{ // Static shapes
{{4}, {2, 3, 4}}
}
},
{
{},
{ // Static shapes
{{4}, {1}}
}
},
{
{},
{ // Static shapes
{{4}, {9}}
}
},
{
{},
{ // Static shapes
{{5}, {5}}
}
}
};
const std::vector<inputShapesPair> dynamicInputShapes1D = {
{
{ // Origin dynamic shapes
{ov::Dimension(4, 6)}, {ov::Dimension(1, 10)}, {ov::Dimension(1, 2)}
},
{ // Dynamic shapes instances
{{4}, {1}, {1}},
{{4}, {9}, {1}},
{{5}, {5}, {1}}
}
}
};
INSTANTIATE_TEST_SUITE_P(smoke_StaticShape1D, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(staticInputShapes1D),
::testing::Values(0),
::testing::Values(0),
::testing::ValuesIn(netPrecisions),
::testing::Values(true),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape1D, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes1D),
::testing::Values(0),
::testing::Values(0),
::testing::ValuesIn(netPrecisions),
::testing::Values(true, false),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
// 2D
const std::vector<inputShapesPair> staticInputShapes2D = {
{
{},
{ // Static shapes
{{4, 7}, {4, 55}}
}
},
{
{},
{ // Static shapes
{{4, 17}, {4, 17}}
}
},
{
{},
{ // Static shapes
{{4, 55}, {4, 7}}
}
}
};
const std::vector<inputShapesPair> dynamicInputShapes2D = {
{
{ // Origin dynamic shapes
{4, ov::Dimension(3, 99)},
{4, ov::Dimension(3, 99)},
{1}
},
{ // Dynamic shapes instances
{{4, 7}, {4, 55}, {1}},
{{4, 55}, {4, 7}, {1}},
{{4, 17}, {4, 17}, {1}}
}
}
};
const std::vector<inputShapesPair> dynamicInputShapes2Dv2 = {
{
{ // Origin dynamic shapes
{ov::Dimension(3, 99), ov::Dimension(3, 99)},
{-1, ov::Dimension(3, 99)},
{1}
},
{ // Dynamic shapes instances
{{4, 7}, {4, 55}, {1}},
{{8, 55}, {5, 7}, {1}}
}
}
};
INSTANTIATE_TEST_SUITE_P(smoke_StaticShape2D, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(staticInputShapes2D),
::testing::Values(1),
::testing::ValuesIn(std::vector<int64_t>{0, 1}),
::testing::ValuesIn(netPrecisions),
::testing::Values(true),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape2D, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes2D),
::testing::Values(1),
::testing::ValuesIn(std::vector<int64_t>{0, 1}),
::testing::ValuesIn(netPrecisions),
::testing::Values(true, false),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape2Dv2, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes2Dv2),
::testing::Values(0),
::testing::Values(0),
::testing::ValuesIn(netPrecisions),
::testing::Values(true, false),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
// 4D
const std::vector<inputShapesPair> staticInputShapes4D = {
{
{},
{ // Static shapes
{{4, 5, 6, 7}, {2, 5, 1}}
}
},
{
{},
{ // Static shapes
{{10, 5, 6, 7}, {2, 5, 2}}
}
},
{
{},
{ // Static shapes
{{16, 5, 6, 7}, {3, 5, 3}}
}
}
};
const std::vector<inputShapesPair> dynamicInputShapes4D = {
{
{ // Origin dynamic shapes
{ov::Dimension(4, 20), 5, 6, 7},
{ov::Dimension(2, 4), 5, ov::Dimension(1, 4)},
{1}
},
{ // Dynamic shapes instances
{{4, 5, 6, 7}, {2, 5, 1}, {1}},
{{10, 5, 6, 7}, {2, 5, 2}, {1}},
{{16, 5, 6, 7}, {3, 5, 3}, {1}}
}
},
{
{ // Origin dynamic shapes
{-1, -1, -1, -1}, {-1, -1, -1}, {1}
},
{ // Dynamic shapes instances
{{4, 5, 6, 4}, {2, 5, 16}, {1}},
{{10, 5, 6, 8}, {2, 5, 24}, {1}}
}
}
};
INSTANTIATE_TEST_SUITE_P(smoke_StaticShape4D, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(staticInputShapes4D),
::testing::ValuesIn(std::vector<int64_t>{0, 1, 2, -1}),
::testing::Values(0),
::testing::ValuesIn(netPrecisions),
::testing::Values(true),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape4D, GatherLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes4D),
::testing::ValuesIn(std::vector<int64_t>{0, 1, 2, -1}),
::testing::Values(0),
::testing::ValuesIn(netPrecisions),
::testing::Values(true, false),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUSpecificParams{})),
GatherLayerTestCPU::getTestCaseName);
} // namespace
} // namespace CPULayerTestsDefinitions