[IE CLDNN] Gather 5d/6d support (#1553)

This commit is contained in:
Lukasz Debski 2020-08-03 09:05:53 +02:00 committed by GitHub
parent e27382070c
commit a17472fed0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 546 additions and 183 deletions

View File

@ -3760,10 +3760,15 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
int axis = gatherLayer->GetParamAsInt("axis", 0);
// Be careful, TensorFlow consist negative axis interpretation bug. Here: -3 = b, -2 = f, -1 = y, but must be -3 = f, -2 = y, -1 = x
auto cldnnAxisFromIE = [](int axis) {
auto cldnnAxisFromIE = [](int axis, cldnn::format inputFormat) {
if (axis == 0) {
return cldnn::gather::gather_axis::along_b;
} else if (axis == 1) {
return cldnn::gather::gather_axis::along_f;
}
if (inputFormat == cldnn::format::bfyx) {
switch (axis) {
case 0: return cldnn::gather::gather_axis::along_b;
case 1: return cldnn::gather::gather_axis::along_f;
case 2: return cldnn::gather::gather_axis::along_y;
case 3: return cldnn::gather::gather_axis::along_x;
case -1: return cldnn::gather::gather_axis::along_y;
@ -3771,6 +3776,33 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
case -3: return cldnn::gather::gather_axis::along_b;
default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
}
} else if (inputFormat == cldnn::format::bfzyx) {
switch (axis) {
case 2: return cldnn::gather::gather_axis::along_z;
case 3: return cldnn::gather::gather_axis::along_y;
case 4: return cldnn::gather::gather_axis::along_x;
case -1: return cldnn::gather::gather_axis::along_y;
case -2: return cldnn::gather::gather_axis::along_z;
case -3: return cldnn::gather::gather_axis::along_f;
case -4: return cldnn::gather::gather_axis::along_b;
default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
}
} else if (inputFormat == cldnn::format::bfwzyx) {
switch (axis) {
case 2: return cldnn::gather::gather_axis::along_w;
case 3: return cldnn::gather::gather_axis::along_z;
case 4: return cldnn::gather::gather_axis::along_y;
case 5: return cldnn::gather::gather_axis::along_x;
case -1: return cldnn::gather::gather_axis::along_y;
case -2: return cldnn::gather::gather_axis::along_z;
case -3: return cldnn::gather::gather_axis::along_w;
case -4: return cldnn::gather::gather_axis::along_f;
case -5: return cldnn::gather::gather_axis::along_b;
default: THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
}
} else {
THROW_CLDNN_EXCEPTION("Unsupported gather axis: " << axis);
}
};
auto gatherLayerName = layer_type_name_ID(layer);
@ -3798,127 +3830,18 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
}
}
auto indicesDims = layer->insData[1].lock()->getTensorDesc().getDims();
auto indicesLayout = layer->insData[1].lock()->getTensorDesc().getLayout();
auto indicesFormat = FormatFromLayout(indicesLayout);
auto inputDims = layer->insData[0].lock()->getTensorDesc().getDims();
auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout();
auto inputFormat = FormatFromLayout(inputLayout);
auto outDimsOriginal = layer->outData[0]->getTensorDesc().getDims();
auto outputLayoutOriginal = layer->outData[0]->getTensorDesc().getLayout();
auto outputFormatOriginal = FormatFromLayout(outputLayoutOriginal);
auto outDims = outDimsOriginal;
auto targetDatatype = DataTypeFromPrecision(layer->precision);
auto nonNegativeAxis = (axis >= 0) ? axis : axis + 3;
// following vector is needed just to check if we can apply bfyx WA
SizeVector originalRequiredDims;
for (size_t d = 0; d < inputDims.size(); d++) {
if ((d == nonNegativeAxis) || (inputDims[d] > 1)) {
originalRequiredDims.push_back(d);
}
}
if (originalRequiredDims.size() < 4) {
// make sure that we will have at least 4 required dimensions
auto originalAxesIt = originalRequiredDims.begin();
for (size_t i = 0; i < 4; i++) {
int dimFoundAtIndex = -1;
for (size_t j = 0; j < originalRequiredDims.size(); j++) {
if (originalRequiredDims[j] == i) {
dimFoundAtIndex = j;
}
}
if (dimFoundAtIndex == -1) {
originalAxesIt = originalRequiredDims.insert(originalAxesIt, i);
}
originalAxesIt++;
}
}
// clDNN primitive is missing proper support of 5d/6d inputs
// but we can still fall back to bfyx format in some cases
bool bfyx_wa = ((inputFormat == cldnn::format::bfzyx || inputFormat == cldnn::format::bfwzyx) &&
(originalRequiredDims.size() == 4) &&
(indicesFormat == cldnn::format::bfyx));
if (bfyx_wa) {
if (indicesDims.size() > 1) {
// reshape the indices dims to 1D (along batch axis)
size_t indDimAcc = std::accumulate(indicesDims.begin(), indicesDims.end(), 1, std::multiplies<size_t>());
SizeVector targetIndDims{ indDimAcc, 1, 1, 1 };
auto reshapeName = reorderedInputs[1] + "_" + layer->name + "_reshape";
auto targetTensor = CldnnTensorFromIEDims(targetIndDims);
auto reshapePrim = cldnn::reshape(reshapeName, reorderedInputs[1], CldnnTensorFromIEDims(targetIndDims));
topology.add(reshapePrim);
AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
reorderedInputs[1] = reshapeName;
// adjust expected output dims
outDims[nonNegativeAxis] = indDimAcc;
outDims.erase(outDims.begin() + nonNegativeAxis + 1, outDims.begin() + nonNegativeAxis + indicesDims.size());
}
// reorder input to bfyx
auto reorderName = reorderedInputs[0] + "_" + layer->name + "_format_reorder";
auto reorderPrim = cldnn::reorder(reorderName, reorderedInputs[0], cldnn::format::bfyx, targetDatatype);
topology.add(reorderPrim);
AddInnerPrimitiveToProfiler(reorderName, gatherLayerName, layer);
reorderedInputs[0] = reorderName;
// calculate new input/output dims in bfyx format
SizeVector targetInDims(4);
SizeVector targetOutDims(4);
for (size_t d = 0; d < 4; d++) {
targetInDims[d] = inputDims[originalRequiredDims[d]];
targetOutDims[d] = outDims[originalRequiredDims[d]];
}
outDims = targetOutDims;
// calculate new axis in bfyx format
for (size_t d = 0; d < originalRequiredDims.size(); d++) {
if (originalRequiredDims[d] == nonNegativeAxis) {
axis = d;
}
}
// reshape the input dims to the ones expected in bfyx format
auto reshapeName = reorderedInputs[0] + "_" + layer->name + "_reshape";
auto targetTensor = CldnnTensorFromIEDims(targetInDims);
auto reshapePrim = cldnn::reshape(reshapeName, reorderedInputs[0], CldnnTensorFromIEDims(targetInDims));
topology.add(reshapePrim);
AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
reorderedInputs[0] = reshapeName;
}
auto outDims = layer->outData[0]->getTensorDesc().getDims();
auto gatherPrim = cldnn::gather(
gatherLayerName,
reorderedInputs[0],
reorderedInputs[1],
cldnnAxisFromIE(axis),
cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)),
CldnnTensorFromIEDims(outDims));
topology.add(gatherPrim);
AddPrimitiveToProfiler(gatherLayerName, layer);
if (bfyx_wa) {
// reorder output back to original format
auto reorderName = gatherLayerName + "_" + layer->name + "_format_reorder";
auto reorderPrim = cldnn::reorder(reorderName, gatherPrim, outputFormatOriginal, targetDatatype);
topology.add(reorderPrim);
AddInnerPrimitiveToProfiler(reorderName, gatherLayerName, layer);
// reshape output back to original dims
auto reshapeName = gatherLayerName + "_" + layer->name + "_reshape";
auto reshapePrim = cldnn::reshape(reshapeName, reorderName, CldnnTensorFromIEDims(outDimsOriginal));
topology.add(reshapePrim);
AddInnerPrimitiveToProfiler(reshapeName, gatherLayerName, layer);
}
}
void CLDNNPlugin::Program::CreateGatherTreePrimitive(cldnn::topology & topology, InferenceEngine::CNNLayerPtr & layer) {

View File

@ -11,39 +11,348 @@ using namespace LayerTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
const std::vector<InferenceEngine::Precision> netPrecisionsFP32 = {
InferenceEngine::Precision::FP32,
};
const std::vector<std::vector<size_t>> inputShapes = {
std::vector<size_t>{10, 20, 30, 40},
const std::vector<InferenceEngine::Precision> netPrecisionsI32 = {
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> netPrecisionsFP16 = {
InferenceEngine::Precision::FP16,
};
const std::vector<std::vector<int>> indices = {
std::vector<int>{0, 3, 2, 1},
};
const std::vector<std::vector<size_t>> indicesShapes = {
std::vector<size_t>{4}
// 5d output not supported yet
// std::vector<size_t>{2, 2}
const std::vector<std::vector<size_t>> indicesShapes12 = {
std::vector<size_t>{4},
std::vector<size_t>{2, 2}
};
const std::vector<int> axes = {0, 1, 2, 3};
const std::vector<std::vector<size_t>> indicesShapes1 = {
std::vector<size_t>{4},
};
const std::vector<std::vector<size_t>> inputShapes6DAxes5 = {
std::vector<size_t>{5, 6, 7, 8, 9, 10},
std::vector<size_t>{1, 1, 7, 8, 9, 10},
std::vector<size_t>{5, 1, 1, 8, 9, 10},
std::vector<size_t>{5, 6, 1, 1, 9, 10},
std::vector<size_t>{5, 6, 7, 1, 1, 10},
std::vector<size_t>{1, 6, 1, 8, 9, 10},
std::vector<size_t>{5, 1, 7, 1, 9, 10},
std::vector<size_t>{5, 6, 1, 8, 1, 10},
std::vector<size_t>{1, 6, 7, 1, 9, 10},
std::vector<size_t>{5, 1, 7, 8, 1, 10},
std::vector<size_t>{1, 6, 7, 8, 1, 10},
};
const auto params = testing::Combine(
const std::vector<int> axes5 = {5};
const auto Gather6dAxes5 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes),
testing::ValuesIn(axes),
testing::ValuesIn(inputShapes),
testing::ValuesIn(netPrecisions),
testing::ValuesIn(indicesShapes1),
testing::ValuesIn(axes5),
testing::ValuesIn(inputShapes6DAxes5),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
const std::vector<std::vector<size_t>> inputShapesAxes4 = {
std::vector<size_t>{5, 6, 7, 8, 9},
std::vector<size_t>{1, 6, 7, 8, 9},
std::vector<size_t>{5, 1, 7, 8, 9},
std::vector<size_t>{5, 6, 1, 8, 9},
std::vector<size_t>{5, 6, 7, 1, 9},
};
const std::vector<std::vector<size_t>> inputShapes6DAxes4 = {
std::vector<size_t>{5, 6, 7, 8, 9, 10},
std::vector<size_t>{1, 1, 7, 8, 9, 10},
std::vector<size_t>{5, 1, 1, 8, 9, 10},
std::vector<size_t>{5, 6, 1, 1, 9, 10},
std::vector<size_t>{5, 6, 7, 1, 9, 1},
std::vector<size_t>{1, 6, 1, 8, 9, 10},
std::vector<size_t>{5, 1, 7, 1, 9, 10},
std::vector<size_t>{5, 6, 1, 8, 9, 1},
std::vector<size_t>{1, 6, 7, 1, 9, 10},
std::vector<size_t>{5, 1, 7, 8, 9, 1},
std::vector<size_t>{1, 6, 7, 8, 9, 1},
};
const std::vector<int> axes4 = {4};
const auto GatherAxes4 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes12),
testing::ValuesIn(axes4),
testing::ValuesIn(inputShapesAxes4),
testing::ValuesIn(netPrecisionsFP16),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
Gather,
GatherAxes4,
GatherLayerTest,
params,
GatherAxes4,
GatherLayerTest::getTestCaseName
);
const auto Gather6dAxes4 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes1),
testing::ValuesIn(axes4),
testing::ValuesIn(inputShapes6DAxes4),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
Gather6dAxes4,
GatherLayerTest,
Gather6dAxes4,
GatherLayerTest::getTestCaseName
);
const std::vector<std::vector<size_t>> inputShapesAxes3 = {
std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{1, 6, 7, 8},
std::vector<size_t>{5, 1, 7, 8},
std::vector<size_t>{5, 6, 1, 8},
std::vector<size_t>{5, 6, 7, 8, 9},
std::vector<size_t>{1, 6, 7, 8, 9},
std::vector<size_t>{5, 1, 7, 8, 9},
std::vector<size_t>{5, 6, 1, 8, 9},
std::vector<size_t>{5, 6, 7, 8, 1},
};
const std::vector<std::vector<size_t>> inputShapes6DAxes3 = {
std::vector<size_t>{5, 6, 7, 8, 9, 10},
std::vector<size_t>{1, 1, 7, 8, 9, 10},
std::vector<size_t>{5, 1, 1, 8, 9, 10},
std::vector<size_t>{5, 6, 1, 8, 1, 10},
std::vector<size_t>{5, 6, 7, 8, 1, 1},
std::vector<size_t>{1, 6, 1, 8, 9, 10},
std::vector<size_t>{5, 1, 7, 8, 1, 10},
std::vector<size_t>{5, 6, 1, 8, 9, 1},
std::vector<size_t>{1, 6, 7, 8, 1, 10},
std::vector<size_t>{5, 1, 7, 8, 9, 1},
std::vector<size_t>{1, 6, 7, 8, 9, 1},
};
const std::vector<int> axes3 = {3};
const auto GatherAxes3 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes12),
testing::ValuesIn(axes3),
testing::ValuesIn(inputShapesAxes3),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
GatherAxes3,
GatherLayerTest,
GatherAxes3,
GatherLayerTest::getTestCaseName
);
const auto Gather6dAxes3 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes1),
testing::ValuesIn(axes3),
testing::ValuesIn(inputShapes6DAxes3),
testing::ValuesIn(netPrecisionsI32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
Gather6dAxes3,
GatherLayerTest,
Gather6dAxes3,
GatherLayerTest::getTestCaseName
);
const std::vector<std::vector<size_t>> inputShapesAxes2 = {
std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{1, 6, 7, 8},
std::vector<size_t>{5, 1, 7, 8},
std::vector<size_t>{5, 6, 7, 1},
std::vector<size_t>{5, 6, 7, 8, 9},
std::vector<size_t>{1, 6, 7, 8, 9},
std::vector<size_t>{5, 1, 7, 8, 9},
std::vector<size_t>{5, 6, 7, 1, 9},
std::vector<size_t>{5, 6, 7, 8, 1},
};
const std::vector<std::vector<size_t>> inputShapes6DAxes2 = {
std::vector<size_t>{5, 6, 7, 8, 9, 10},
std::vector<size_t>{1, 1, 7, 8, 9, 10},
std::vector<size_t>{5, 1, 7, 1, 9, 10},
std::vector<size_t>{5, 6, 7, 1, 1, 10},
std::vector<size_t>{5, 6, 7, 8, 1, 1},
std::vector<size_t>{1, 6, 7, 1, 9, 10},
std::vector<size_t>{5, 1, 7, 8, 1, 10},
std::vector<size_t>{5, 6, 7, 1, 9, 1},
std::vector<size_t>{1, 6, 7, 8, 1, 10},
std::vector<size_t>{5, 1, 7, 8, 9, 1},
std::vector<size_t>{1, 6, 7, 8, 9, 1},
};
const std::vector<int> axes2 = {2};
const auto GatherAxes2 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes12),
testing::ValuesIn(axes2),
testing::ValuesIn(inputShapesAxes2),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
GatherAxes2,
GatherLayerTest,
GatherAxes2,
GatherLayerTest::getTestCaseName
);
const auto Gather6dAxes2 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes1),
testing::ValuesIn(axes2),
testing::ValuesIn(inputShapes6DAxes2),
testing::ValuesIn(netPrecisionsFP16),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
Gather6dAxes2,
GatherLayerTest,
Gather6dAxes2,
GatherLayerTest::getTestCaseName
);
const std::vector<std::vector<size_t>> inputShapesAxes1 = {
std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{1, 6, 7, 8},
std::vector<size_t>{5, 6, 1, 8},
std::vector<size_t>{5, 6, 7, 1},
std::vector<size_t>{5, 6, 7, 8, 9},
std::vector<size_t>{1, 6, 7, 8, 9},
std::vector<size_t>{5, 6, 1, 8, 9},
std::vector<size_t>{5, 6, 7, 1, 9},
std::vector<size_t>{5, 6, 7, 8, 1},
};
const std::vector<std::vector<size_t>> inputShapes6DAxes1 = {
std::vector<size_t>{5, 6, 7, 8, 9, 10},
std::vector<size_t>{1, 6, 1, 8, 9, 10},
std::vector<size_t>{5, 6, 1, 1, 9, 10},
std::vector<size_t>{5, 6, 7, 1, 1, 10},
std::vector<size_t>{5, 6, 7, 8, 1, 1},
std::vector<size_t>{1, 6, 7, 1, 9, 10},
std::vector<size_t>{5, 6, 1, 8, 1, 10},
std::vector<size_t>{5, 6, 1, 8, 9, 1},
std::vector<size_t>{1, 6, 7, 8, 1, 10},
std::vector<size_t>{1, 6, 7, 8, 9, 1},
std::vector<size_t>{5, 6, 7, 1, 9, 1},
};
const std::vector<int> axes1 = {1};
const auto GatherAxes1 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes12),
testing::ValuesIn(axes1),
testing::ValuesIn(inputShapesAxes1),
testing::ValuesIn(netPrecisionsI32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
GatherAxes1,
GatherLayerTest,
GatherAxes1,
GatherLayerTest::getTestCaseName
);
const auto Gather6dAxes1 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes1),
testing::ValuesIn(axes1),
testing::ValuesIn(inputShapes6DAxes1),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
Gather6dAxes1,
GatherLayerTest,
Gather6dAxes1,
GatherLayerTest::getTestCaseName
);
const std::vector<std::vector<size_t>> inputShapesAxes0 = {
std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{5, 1, 7, 8},
std::vector<size_t>{5, 6, 1, 8},
std::vector<size_t>{5, 6, 7, 1},
std::vector<size_t>{5, 6, 7, 8, 9},
std::vector<size_t>{5, 1, 7, 8, 9},
std::vector<size_t>{5, 6, 1, 8, 9},
std::vector<size_t>{5, 6, 7, 1, 9},
std::vector<size_t>{5, 6, 7, 8, 1},
};
const std::vector<std::vector<size_t>> inputShapes6DAxes0 = {
std::vector<size_t>{5, 6, 7, 8, 9, 10},
std::vector<size_t>{5, 1, 1, 8, 9, 10},
std::vector<size_t>{5, 6, 1, 1, 9, 10},
std::vector<size_t>{5, 6, 7, 1, 1, 10},
std::vector<size_t>{5, 6, 7, 8, 1, 1},
std::vector<size_t>{5, 1, 7, 1, 9, 10},
std::vector<size_t>{5, 6, 1, 8, 1, 10},
std::vector<size_t>{5, 6, 1, 8, 9, 1},
std::vector<size_t>{5, 1, 7, 8, 1, 10},
std::vector<size_t>{5, 1, 7, 8, 9, 1},
std::vector<size_t>{5, 6, 7, 1, 9, 1},
};
const std::vector<int> axes0 = {0};
const auto GatherAxes0 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes12),
testing::ValuesIn(axes0),
testing::ValuesIn(inputShapesAxes0),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
GatherAxes0,
GatherLayerTest,
GatherAxes0,
GatherLayerTest::getTestCaseName
);
const auto Gather6dAxes0 = testing::Combine(
testing::ValuesIn(indices),
testing::ValuesIn(indicesShapes1),
testing::ValuesIn(axes0),
testing::ValuesIn(inputShapes6DAxes0),
testing::ValuesIn(netPrecisionsFP32),
testing::Values(CommonTestUtils::DEVICE_GPU)
);
INSTANTIATE_TEST_CASE_P(
Gather6dAxes0,
GatherLayerTest,
Gather6dAxes0,
GatherLayerTest::getTestCaseName
);

View File

@ -7,18 +7,18 @@
INSTANTIATE_TEST_CASE_P(
smoke_GPU_TestsGather, GatherTFTests,
::testing::Values(
gatherTF_test_params{ "GPU", "FP32", { 1, 4 }, in0,{ 2, 2 }, dict2D, 0, { 1, 4, 2 }, ref_in0_a0_d22 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
gatherTF_test_params{ "GPU", "FP32", { 1, 4 }, in0,{ 2, 2 }, dict2D, 0, { 1, 4, 1, 2 }, ref_in0_a0_d22 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 1, 2, 3 }, ref_in0_a0_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 1, 2, 3 }, ref_in0_a0_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 1, 2, 2 }, ref_in1_a0_d322 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 1, 2, 2 }, ref_in1_a0_d322 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2, 1 }, ref_in1_a1_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2, 1 }, ref_in1_a1_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 }
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2, 1 }, ref_in1_a2_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2, 1 }, ref_in1_a2_d223 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 },
gatherTF_test_params{ "GPU", "FP32", { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2, 1 }, ref_in0_a2_d232 }
));

View File

@ -35,7 +35,9 @@ struct gather : public primitive_base<gather> {
along_b,
along_f,
along_x,
along_y
along_y,
along_z,
along_w
};
/// @brief Constructs gather primitive.

View File

@ -480,6 +480,8 @@ enum class ContractMode {
enum class GatherAxis {
X,
Y,
Z,
W,
FEATURE,
BATCH,
};

View File

@ -23,10 +23,16 @@ namespace kernel_selector {
static size_t GetGatherChannelIndex(const gather_params& params) {
Tensor::DataChannelName name = Tensor::DataChannelName::X;
size_t inputSize = params.inputs[0].GetDims().size();
switch (params.axis) {
case GatherAxis::X:
return 3;
return inputSize - 1;
case GatherAxis::Y:
return inputSize - 2;
case GatherAxis::Z:
return inputSize - 3;
case GatherAxis::W:
return 2;
case GatherAxis::FEATURE:
return 1;
@ -51,6 +57,10 @@ ParamsKey GatherKernelRef::GetSupportedKey() const {
k.EnableOutputDataType(Datatype::UINT8);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableInputLayout(DataLayout::bfzyx);
k.EnableOutputLayout(DataLayout::bfzyx);
k.EnableInputLayout(DataLayout::bfwzyx);
k.EnableOutputLayout(DataLayout::bfwzyx);
k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
@ -82,8 +92,22 @@ static inline std::string GetOrderString(std::vector<std::string>& order) {
return order_str;
}
static inline std::vector<std::string> GetOrder(size_t size) {
std::vector<std::string> idx_order;
if (size <= 4) {
idx_order = {"b", "f", "y", "x"};
} else if (size == 5) {
idx_order = {"b", "f", "z", "y", "x"};
} else if (size == 6) {
idx_order = {"b", "f", "w", "z", "y", "x"};
}
return idx_order;
}
static std::string GetDictionaryIndexOrder(const gather_params& params, size_t axis) {
std::vector<std::string> default_order = { "b", "f", "y", "x" };
std::vector<std::string> idx_order = GetOrder(params.output.GetDims().size());
const std::string input_axis_index_macro = "INPUT_AXIS_INDEX";
const std::string zeroVal = "0";
@ -92,38 +116,57 @@ static std::string GetDictionaryIndexOrder(const gather_params& params, size_t a
// Shift indices of Gather dictionary input related to output dims
for (size_t i = axis + 1; i < dictionary_dims_num; i++)
default_order[i] = default_order[i + indices_dims_num - 1];
idx_order[i] = idx_order[i + indices_dims_num - 1];
for (size_t i = dictionary_dims_num; i < default_order.size(); i++)
default_order[i] = zeroVal;
for (size_t i = dictionary_dims_num; i < idx_order.size(); i++)
idx_order[i] = zeroVal;
default_order[axis] = input_axis_index_macro;
// Fix size to inputs[0] dims size
for (size_t i = 0; i < params.output.GetDims().size() - params.inputs[0].GetDims().size(); i++)
idx_order.pop_back();
return GetOrderString(default_order);
idx_order[axis] = input_axis_index_macro;
return GetOrderString(idx_order);
}
static std::string GetIndecesIdxOrder(const gather_params& params, size_t axis) {
std::vector<std::string> default_order = { "b", "f", "y", "x" };
std::vector<std::string> idx_order = GetOrder(params.output.GetDims().size());
const std::string zero_val = "0";
size_t indices_dims_num = GetNonEmptyDimsNumber(params.inputs[1]);
// Shift indices of Gather indices input related to output dims
for (size_t i = 0; i < indices_dims_num; i++)
default_order[i] = default_order[axis + i];
idx_order[i] = idx_order[axis + i];
for (size_t i = indices_dims_num; i < default_order.size(); i++)
default_order[i] = zero_val;
for (size_t i = indices_dims_num; i < idx_order.size(); i++)
idx_order[i] = zero_val;
return GetOrderString(default_order);
// Fix size to inputs[1] dims size
for (size_t i = 0; i < params.output.GetDims().size() - params.inputs[1].GetDims().size(); i++)
idx_order.pop_back();
return GetOrderString(idx_order);
}
CommonDispatchData GatherKernelRef::SetDefault(const gather_params& params, const optional_params&) const {
CommonDispatchData runInfo;
const auto& output = params.output;
std::vector<size_t> global = {output.Batch().v, output.Feature().v,output.X().v * output.Y().v};
std::vector<size_t> local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
std::vector<size_t> global;
std::vector<size_t> local;
if (output.GetLayout() == DataLayout::bfyx) {
global = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
} else if (output.GetLayout() == DataLayout::bfzyx) {
global = {output.X().v, output.Y().v * output.Z().v, output.Feature().v * output.Batch().v};
} else {
global = {output.X().v * output.Y().v, output.Z().v * output.W().v, output.Feature().v * output.Batch().v};
}
local = GetOptimalLocalWorkGroupSizes(global, params.engineInfo);
runInfo.gws0 = global[0];
runInfo.gws1 = global[1];
@ -145,7 +188,9 @@ JitConstants GatherKernelRef::GetJitConstants(const gather_params& params) const
jit.AddConstant(MakeJitConstant("INDICES_INDEX_ORDER", GetIndecesIdxOrder(params, GetGatherChannelIndex(params))));
if (!params.fused_ops.empty()) {
FusedOpsConfiguration conf = { "", {"b", "f", "y", "x"}, "val", params.inputs[0].GetDType() };
std::vector<std::string> idx_order = GetOrder(params.inputs[0].GetDims().size());
FusedOpsConfiguration conf = { "", idx_order, "val", params.inputs[0].GetDType() };
jit.Merge(MakeFusedOpsJitConstants(params, {conf}));
}

View File

@ -18,6 +18,7 @@
#define INPUT_AXIS_INDEX (uint)indices[indices_idx]
#define GET_DICTIONARY_INDEX(idx_order) INPUT0_GET_INDEX(idx_order)
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
#define GET_INDEX(prefix, num, idx_order) CAT(CAT(prefix, num), _GET_INDEX)(idx_order)
KERNEL(gather_ref)(const __global INPUT0_TYPE* dictionary,
const __global INPUT1_TYPE* indices,
@ -27,15 +28,32 @@ KERNEL(gather_ref)(const __global INPUT0_TYPE* dictionary,
#endif
)
{
const uint b = get_global_id(0);
const uint f = get_global_id(1);
const uint yx = get_global_id(2);
const uint y = yx / OUTPUT_SIZE_X;
const uint x = yx % OUTPUT_SIZE_X;
#if OUTPUT_DIMS == 6
#define ORDER b,f,w,z,y,x
const uint x = (uint)get_global_id(0) % OUTPUT_SIZE_X;
const uint y = (uint)get_global_id(0) / OUTPUT_SIZE_X;
const uint z = (uint)get_global_id(1) % OUTPUT_SIZE_Z;
const uint w = (uint)get_global_id(1) / OUTPUT_SIZE_Z;
const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
#elif OUTPUT_DIMS == 5
#define ORDER b,f,z,y,x
const uint x = (uint)get_global_id(0);
const uint y = (uint)get_global_id(1) % OUTPUT_SIZE_Y;
const uint z = (uint)get_global_id(1) / OUTPUT_SIZE_Y;
const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
#elif OUTPUT_DIMS == 4
#define ORDER b,f,y,x
const uint x = (uint)get_global_id(0);
const uint y = (uint)get_global_id(1);
const uint f = (uint)get_global_id(2) % OUTPUT_FEATURE_NUM;
const uint b = (uint)get_global_id(2) / OUTPUT_FEATURE_NUM;
#endif
const uint indices_idx = GET_INDICES_INDEX(INDICES_INDEX_ORDER);
const uint dictionary_idx = GET_DICTIONARY_INDEX(DICTIONARY_INDEX_ORDER);
const uint output_idx = OUTPUT_GET_INDEX(b, f, y, x);
const uint output_idx = GET_INDEX(OUTPUT,,ORDER);
INPUT0_TYPE val = dictionary[dictionary_idx];

View File

@ -31,16 +31,28 @@ layout gather_inst::calc_output_layout(gather_node const& node) {
auto desc = node.get_primitive();
auto input_layout = node.input(0).get_output_layout();
auto input_format = input_layout.format;
auto output_shape = desc->output_shape;
auto output_format = input_layout.format;
int spatialNum = 0;
for (auto i : node.input(1).get_output_layout().size.raw)
spatialNum += (i > 1) ? 1 : 0;
// change output format if input indeces > 1
if (spatialNum == 2 && output_format == cldnn::format::bfzyx) {
output_format = cldnn::format::bfwzyx;
} else if (spatialNum == 2 && output_format == cldnn::format::bfyx) {
output_format = cldnn::format::bfzyx;
} else if (spatialNum == 3 && output_format == cldnn::format::bfyx) {
output_format = cldnn::format::bfwzyx;
}
auto output_type = input_layout.data_type;
if (node.has_fused_primitives()) {
output_type = node.get_fused_output_layout().data_type;
}
return layout{output_type, input_format, output_shape};
return layout{output_type, output_format, output_shape};
}
std::string gather_inst::to_string(gather_node const& node) {

View File

@ -32,6 +32,10 @@ kernel_selector::gather_axis convert_axis(gather::gather_axis axis) {
return kernel_selector::gather_axis::X;
case gather::along_y:
return kernel_selector::gather_axis::Y;
case gather::along_z:
return kernel_selector::gather_axis::Z;
case gather::along_w:
return kernel_selector::gather_axis::W;
case gather::along_f:
return kernel_selector::gather_axis::FEATURE;
case gather::along_b:
@ -76,6 +80,14 @@ attach_gather_gpu::attach_gather_gpu() {
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
implementation_map<gather>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
}
} // namespace detail

View File

@ -4959,6 +4959,18 @@ struct gather_test_params {
#define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> {
public:
void execute(gather_test_params& p) {
@ -4985,6 +4997,10 @@ public:
return p.dictionary_shape.spatial[0];
case cldnn::gather::gather_axis::along_y:
return p.dictionary_shape.spatial[1];
case cldnn::gather::gather_axis::along_z:
return p.dictionary_shape.spatial[2];
case cldnn::gather::gather_axis::along_w:
return p.dictionary_shape.spatial[3];
case cldnn::gather::gather_axis::along_f:
return p.dictionary_shape.feature[0];
case cldnn::gather::gather_axis::along_b:
@ -5030,6 +5046,18 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_quantize,
gather_test_params{ CASE_GATHER_FP16_3, 2, 3 },
gather_test_params{ CASE_GATHER_FP16_4, 2, 3 },
gather_test_params{ CASE_GATHER_FP16_5, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP32_1, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP32_2, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP32_3, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP32_4, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP32_5, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP16_1, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP16_2, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP16_3, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP16_4, 2, 3 },
gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 3 },
}), );
class gather_scale_activation : public GatherPrimitiveFusingTest {};
@ -5061,6 +5089,18 @@ INSTANTIATE_TEST_CASE_P(fusings_gpu, gather_scale_activation,
gather_test_params{ CASE_GATHER_FP16_3, 2, 4 },
gather_test_params{ CASE_GATHER_FP16_4, 2, 4 },
gather_test_params{ CASE_GATHER_FP16_5, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP32_1, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP32_2, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP32_3, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP32_4, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP32_5, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP16_1, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP16_2, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP16_3, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP16_4, 2, 4 },
gather_test_params{ CASE_GATHER_5D_FP16_5, 2, 4 },
}), );
/* ------------------------------------------------------------------------------------------------------------ */

View File

@ -125,7 +125,7 @@ TEST(gather_gpu_fp16, d222_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
@ -186,7 +186,7 @@ TEST(gather_gpu_fp16, d22_axisY) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
@ -247,7 +247,7 @@ TEST(gather_gpu_fp16, d22_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
);
network network(engine, topology);
@ -366,7 +366,7 @@ TEST(gather_gpu_fp32, d222_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
@ -427,7 +427,7 @@ TEST(gather_gpu_fp32, d22_axisY) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
@ -488,7 +488,7 @@ TEST(gather_gpu_fp32, d22_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
);
network network(engine, topology);
@ -549,7 +549,7 @@ TEST(gather_gpu_int32, d22_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
);
network network(engine, topology);
@ -668,7 +668,7 @@ TEST(gather_gpu_int32, d222_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);
@ -729,7 +729,7 @@ TEST(gather_gpu_int32, d22_axisY) {
topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout()));
topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 2))
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
);
network network(engine, topology);