[IE CLDNN] Gather fix (#3273)

This commit is contained in:
Sergey Shlyapnikov 2020-11-23 16:40:09 +03:00 committed by GitHub
parent 9b85d67cce
commit 5de04e9b94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 102 additions and 54 deletions

View File

@ -4302,12 +4302,14 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout(); auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout();
auto outDims = layer->outData[0]->getTensorDesc().getDims(); auto outDims = layer->outData[0]->getTensorDesc().getDims();
auto outLayout = layer->outData[0]->getTensorDesc().getLayout();
auto gatherPrim = cldnn::gather( auto gatherPrim = cldnn::gather(
gatherLayerName, gatherLayerName,
reorderedInputs[0], reorderedInputs[0],
reorderedInputs[1], reorderedInputs[1],
cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)), cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)),
FormatFromLayout(outLayout),
CldnnTensorFromIEDims(outDims)); CldnnTensorFromIEDims(outDims));
topology.add(gatherPrim); topology.add(gatherPrim);

View File

@ -197,6 +197,7 @@ INSTANTIATE_TEST_CASE_P(
); );
const std::vector<std::vector<size_t>> inputShapesAxes2 = { const std::vector<std::vector<size_t>> inputShapesAxes2 = {
std::vector<size_t>{5, 6, 7},
std::vector<size_t>{5, 6, 7, 8}, std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{1, 6, 7, 8}, std::vector<size_t>{1, 6, 7, 8},
std::vector<size_t>{5, 1, 7, 8}, std::vector<size_t>{5, 1, 7, 8},
@ -265,6 +266,8 @@ INSTANTIATE_TEST_CASE_P(
); );
const std::vector<std::vector<size_t>> inputShapesAxes1 = { const std::vector<std::vector<size_t>> inputShapesAxes1 = {
std::vector<size_t>{5, 6},
std::vector<size_t>{5, 6, 7},
std::vector<size_t>{5, 6, 7, 8}, std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{1, 6, 7, 8}, std::vector<size_t>{1, 6, 7, 8},
std::vector<size_t>{5, 6, 1, 8}, std::vector<size_t>{5, 6, 1, 8},
@ -333,6 +336,9 @@ INSTANTIATE_TEST_CASE_P(
); );
const std::vector<std::vector<size_t>> inputShapesAxes0 = { const std::vector<std::vector<size_t>> inputShapesAxes0 = {
std::vector<size_t>{5},
std::vector<size_t>{5, 6},
std::vector<size_t>{5, 6, 7},
std::vector<size_t>{5, 6, 7, 8}, std::vector<size_t>{5, 6, 7, 8},
std::vector<size_t>{5, 1, 7, 8}, std::vector<size_t>{5, 1, 7, 8},
std::vector<size_t>{5, 6, 1, 8}, std::vector<size_t>{5, 6, 1, 8},

View File

@ -1,5 +1,5 @@
/* /*
// Copyright (c) 2019 Intel Corporation // Copyright (c) 2019-2020 Intel Corporation
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -50,13 +50,16 @@ struct gather : public primitive_base<gather> {
const primitive_id& dict, const primitive_id& dict,
const primitive_id& idx, const primitive_id& idx,
const gather_axis axis, const gather_axis axis,
const format& output_format,
const tensor& output_shape, const tensor& output_shape,
const padding& output_padding = padding()) const padding& output_padding = padding())
: primitive_base(id, {dict, idx}, output_padding), axis(axis), output_shape(output_shape) {} : primitive_base(id, {dict, idx}, output_padding), axis(axis), output_format(output_format), output_shape(output_shape) {}
/// @brief Gathering axis /// @brief Gathering axis
gather_axis axis; gather_axis axis;
/// @brief Gathering input shape /// @brief Gather output format
format output_format;
/// @brief Gather output shape
tensor output_shape; tensor output_shape;
}; };
/// @} /// @}

View File

@ -31,21 +31,8 @@ layout gather_inst::calc_output_layout(gather_node const& node) {
auto desc = node.get_primitive(); auto desc = node.get_primitive();
auto input_layout = node.input(0).get_output_layout(); auto input_layout = node.input(0).get_output_layout();
auto output_format = desc->output_format;
auto output_shape = desc->output_shape; auto output_shape = desc->output_shape;
auto output_format = input_layout.format;
int spatialNum = 0;
for (auto i : node.input(1).get_output_layout().size.raw)
spatialNum += (i > 1) ? 1 : 0;
// change output format if input indeces > 1
if (spatialNum == 2 && output_format == cldnn::format::bfzyx) {
output_format = cldnn::format::bfwzyx;
} else if (spatialNum == 2 && output_format == cldnn::format::bfyx) {
output_format = cldnn::format::bfzyx;
} else if (spatialNum == 3 && output_format == cldnn::format::bfyx) {
output_format = cldnn::format::bfwzyx;
}
auto output_type = input_layout.data_type; auto output_type = input_layout.data_type;
if (node.has_fused_primitives()) { if (node.has_fused_primitives()) {

View File

@ -5312,6 +5312,7 @@ struct gather_test_params {
tensor dictionary_shape; tensor dictionary_shape;
tensor indices_shape; tensor indices_shape;
tensor out_shape; tensor out_shape;
format out_format;
cldnn::gather::gather_axis axis; cldnn::gather::gather_axis axis;
data_types data_type; data_types data_type;
format input_format; format input_format;
@ -5321,29 +5322,29 @@ struct gather_test_params {
size_t expected_not_fused_primitives; size_t expected_not_fused_primitives;
}; };
#define CASE_GATHER_FP32_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx #define CASE_GATHER_FP32_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GATHER_FP32_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx #define CASE_GATHER_FP32_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GATHER_FP32_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfyx, data_types::f32, format::bfyx #define CASE_GATHER_FP32_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, format::bfyx, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GATHER_FP32_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx #define CASE_GATHER_FP32_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GATHER_FP32_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx #define CASE_GATHER_FP32_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
#define CASE_GATHER_FP16_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GATHER_FP16_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_FP16_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GATHER_FP16_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_FP16_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GATHER_FP16_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, format::bfyx, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx #define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, format::bfzyx, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx #define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, format::bfzyx, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx #define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, format::bfzyx, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx #define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, format::bfzyx, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx #define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, format::bfzyx, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
#define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx #define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, format::bfzyx, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx #define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, format::bfzyx, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx #define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, format::bfzyx, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx #define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, format::bfzyx, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx #define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, format::bfzyx, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> { class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> {
public: public:
@ -5398,7 +5399,7 @@ TEST_P(gather_quantize, basic) {
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)), data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)), data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)), data("out_hi", get_mem(get_single_element_layout(p), 127)),
gather("gather_prim", "input", "gather_indices", p.axis, p.out_shape), gather("gather_prim", "input", "gather_indices", p.axis, p.out_format, p.out_shape),
quantize("quantize", "gather_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8), quantize("quantize", "gather_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32) reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
); );
@ -5440,7 +5441,7 @@ TEST_P(gather_scale_activation, basic) {
create_topologies(input_layout("input", get_input_layout(p)), create_topologies(input_layout("input", get_input_layout(p)),
data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)))), data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)))),
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)), data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
gather("gather_prim", "input", "gather_indices", p.axis, p.out_shape), gather("gather_prim", "input", "gather_indices", p.axis, p.out_format, p.out_shape),
activation("activation", "gather_prim", activation_func::abs), activation("activation", "gather_prim", activation_func::abs),
scale("scale", "activation", "scale_data"), scale("scale", "activation", "scale_data"),
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32) reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)

View File

@ -63,7 +63,7 @@ TEST(gather_gpu_fp16, d14_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(1, 4, 1, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(1, 4, 1, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -125,7 +125,7 @@ TEST(gather_gpu_fp16, d222_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -186,7 +186,7 @@ TEST(gather_gpu_fp16, d22_axisY) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -247,7 +247,7 @@ TEST(gather_gpu_fp16, d22_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -305,7 +305,7 @@ TEST(gather_gpu_fp32, d14_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(1, 4, 1, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(1, 4, 1, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -366,7 +366,7 @@ TEST(gather_gpu_fp32, d222_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -427,7 +427,7 @@ TEST(gather_gpu_fp32, d22_axisY) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -488,7 +488,7 @@ TEST(gather_gpu_fp32, d22_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -549,7 +549,7 @@ TEST(gather_gpu_int32, d22_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -607,7 +607,7 @@ TEST(gather_gpu_int32, d14_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(1, 4, 1, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(1, 4, 1, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -668,7 +668,7 @@ TEST(gather_gpu_int32, d222_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -729,7 +729,7 @@ TEST(gather_gpu_int32, d22_axisY) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -793,7 +793,7 @@ TEST(gather_gpu_fp32, d41_axisB) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(4, 1, 3, 2)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(4, 1, 3, 2))
); );
network network(engine, topology); network network(engine, topology);
@ -856,7 +856,7 @@ TEST(gather_gpu_fp32, d41_axisF) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 4, 2, 1)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 4, 2, 1))
); );
network network(engine, topology); network network(engine, topology);
@ -915,7 +915,7 @@ TEST(gather_gpu_fp32, d2_axisX) {
topology.add(input_layout("InputDictionary", input1.get_layout())); topology.add(input_layout("InputDictionary", input1.get_layout()));
topology.add(input_layout("InputText", input2.get_layout())); topology.add(input_layout("InputText", input2.get_layout()));
topology.add( topology.add(
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1)) gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 1))
); );
network network(engine, topology); network network(engine, topology);
@ -938,3 +938,52 @@ TEST(gather_gpu_fp32, d2_axisX) {
EXPECT_EQ(expected_results[i], output_ptr[i]) << " at i=" << i; EXPECT_EQ(expected_results[i], output_ptr[i]) << " at i=" << i;
} }
} }
TEST(gather_gpu_fp32, 322_axisF) {
    // Gather along the feature axis using a two-dimensional index tensor.
    //   Dictionary : 3x3x1x1
    //   Indexes    : 2x2x1x1
    //   Axis       : 1 (along_f)
    //   Output     : 3x2x2x1
    // Every buffer in this case holds i32 values.
    // NOTE(review): the case is registered in the gather_gpu_fp32 suite even
    // though it allocates and reads i32 data.

    // Each dictionary row {r0, r1, r2} is sampled with the indices {1, 0, 2, 1},
    // giving one group of four values per row.
    std::vector<int> expected_results = {
        1, 0, 2, 1,
        11, 10, 12, 11,
        21, 20, 22, 21
    };

    engine gpu_engine;

    auto dictionary_mem = memory::allocate(gpu_engine, { data_types::i32, format::bfyx, { 3, 3, 1, 1 } });
    auto indices_mem = memory::allocate(gpu_engine, { data_types::i32, format::bfyx, { 2, 2, 1, 1 } });

    set_values(dictionary_mem, { 0, 1, 2, 10, 11, 12, 20, 21, 22 });
    set_values(indices_mem, { 1, 0, 2, 1 });

    auto gather_dim = cldnn::gather::gather_axis::along_f;

    // Two network inputs feed a single gather primitive with an explicit bfyx output.
    topology gather_topology;
    gather_topology.add(input_layout("InputDictionary", dictionary_mem.get_layout()));
    gather_topology.add(input_layout("InputText", indices_mem.get_layout()));
    gather_topology.add(
        gather("gather", "InputDictionary", "InputText", gather_dim, format::bfyx, tensor(3, 2, 1, 2))
    );

    network gather_network(gpu_engine, gather_topology);
    gather_network.set_input_data("InputDictionary", dictionary_mem);
    gather_network.set_input_data("InputText", indices_mem);

    auto executed = gather_network.execute();
    auto gathered_mem = executed.at("gather").get_memory();
    auto output_ptr = gathered_mem.pointer<int>();

    // Abort on a size mismatch so the element-wise comparison never reads out of range.
    ASSERT_EQ(expected_results.size(), output_ptr.size());
    for (size_t i = 0; i != expected_results.size(); ++i) {
        EXPECT_EQ(expected_results[i], output_ptr[i]) << i;
    }
}