[IE CLDNN] Gather fix (#3273)
This commit is contained in:
parent
9b85d67cce
commit
5de04e9b94
@ -4302,12 +4302,14 @@ void Program::CreateGatherPrimitive(cldnn::topology& topology, InferenceEngine::
|
|||||||
|
|
||||||
auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout();
|
auto inputLayout = layer->insData[0].lock()->getTensorDesc().getLayout();
|
||||||
auto outDims = layer->outData[0]->getTensorDesc().getDims();
|
auto outDims = layer->outData[0]->getTensorDesc().getDims();
|
||||||
|
auto outLayout = layer->outData[0]->getTensorDesc().getLayout();
|
||||||
|
|
||||||
auto gatherPrim = cldnn::gather(
|
auto gatherPrim = cldnn::gather(
|
||||||
gatherLayerName,
|
gatherLayerName,
|
||||||
reorderedInputs[0],
|
reorderedInputs[0],
|
||||||
reorderedInputs[1],
|
reorderedInputs[1],
|
||||||
cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)),
|
cldnnAxisFromIE(axis, FormatFromLayout(inputLayout)),
|
||||||
|
FormatFromLayout(outLayout),
|
||||||
CldnnTensorFromIEDims(outDims));
|
CldnnTensorFromIEDims(outDims));
|
||||||
|
|
||||||
topology.add(gatherPrim);
|
topology.add(gatherPrim);
|
||||||
|
@ -197,6 +197,7 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const std::vector<std::vector<size_t>> inputShapesAxes2 = {
|
const std::vector<std::vector<size_t>> inputShapesAxes2 = {
|
||||||
|
std::vector<size_t>{5, 6, 7},
|
||||||
std::vector<size_t>{5, 6, 7, 8},
|
std::vector<size_t>{5, 6, 7, 8},
|
||||||
std::vector<size_t>{1, 6, 7, 8},
|
std::vector<size_t>{1, 6, 7, 8},
|
||||||
std::vector<size_t>{5, 1, 7, 8},
|
std::vector<size_t>{5, 1, 7, 8},
|
||||||
@ -265,6 +266,8 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const std::vector<std::vector<size_t>> inputShapesAxes1 = {
|
const std::vector<std::vector<size_t>> inputShapesAxes1 = {
|
||||||
|
std::vector<size_t>{5, 6},
|
||||||
|
std::vector<size_t>{5, 6, 7},
|
||||||
std::vector<size_t>{5, 6, 7, 8},
|
std::vector<size_t>{5, 6, 7, 8},
|
||||||
std::vector<size_t>{1, 6, 7, 8},
|
std::vector<size_t>{1, 6, 7, 8},
|
||||||
std::vector<size_t>{5, 6, 1, 8},
|
std::vector<size_t>{5, 6, 1, 8},
|
||||||
@ -333,6 +336,9 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const std::vector<std::vector<size_t>> inputShapesAxes0 = {
|
const std::vector<std::vector<size_t>> inputShapesAxes0 = {
|
||||||
|
std::vector<size_t>{5},
|
||||||
|
std::vector<size_t>{5, 6},
|
||||||
|
std::vector<size_t>{5, 6, 7},
|
||||||
std::vector<size_t>{5, 6, 7, 8},
|
std::vector<size_t>{5, 6, 7, 8},
|
||||||
std::vector<size_t>{5, 1, 7, 8},
|
std::vector<size_t>{5, 1, 7, 8},
|
||||||
std::vector<size_t>{5, 6, 1, 8},
|
std::vector<size_t>{5, 6, 1, 8},
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
// Copyright (c) 2019 Intel Corporation
|
// Copyright (c) 2019-2020 Intel Corporation
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -50,13 +50,16 @@ struct gather : public primitive_base<gather> {
|
|||||||
const primitive_id& dict,
|
const primitive_id& dict,
|
||||||
const primitive_id& idx,
|
const primitive_id& idx,
|
||||||
const gather_axis axis,
|
const gather_axis axis,
|
||||||
|
const format& output_format,
|
||||||
const tensor& output_shape,
|
const tensor& output_shape,
|
||||||
const padding& output_padding = padding())
|
const padding& output_padding = padding())
|
||||||
: primitive_base(id, {dict, idx}, output_padding), axis(axis), output_shape(output_shape) {}
|
: primitive_base(id, {dict, idx}, output_padding), axis(axis), output_format(output_format), output_shape(output_shape) {}
|
||||||
|
|
||||||
/// @brief Gathering axis
|
/// @brief Gathering axis
|
||||||
gather_axis axis;
|
gather_axis axis;
|
||||||
/// @brief Gathering input shape
|
/// @brief Gather output format
|
||||||
|
format output_format;
|
||||||
|
/// @brief Gather output shape
|
||||||
tensor output_shape;
|
tensor output_shape;
|
||||||
};
|
};
|
||||||
/// @}
|
/// @}
|
||||||
|
15
inference-engine/thirdparty/clDNN/src/gather.cpp
vendored
15
inference-engine/thirdparty/clDNN/src/gather.cpp
vendored
@ -31,21 +31,8 @@ layout gather_inst::calc_output_layout(gather_node const& node) {
|
|||||||
auto desc = node.get_primitive();
|
auto desc = node.get_primitive();
|
||||||
|
|
||||||
auto input_layout = node.input(0).get_output_layout();
|
auto input_layout = node.input(0).get_output_layout();
|
||||||
|
auto output_format = desc->output_format;
|
||||||
auto output_shape = desc->output_shape;
|
auto output_shape = desc->output_shape;
|
||||||
auto output_format = input_layout.format;
|
|
||||||
|
|
||||||
int spatialNum = 0;
|
|
||||||
for (auto i : node.input(1).get_output_layout().size.raw)
|
|
||||||
spatialNum += (i > 1) ? 1 : 0;
|
|
||||||
|
|
||||||
// change output format if input indeces > 1
|
|
||||||
if (spatialNum == 2 && output_format == cldnn::format::bfzyx) {
|
|
||||||
output_format = cldnn::format::bfwzyx;
|
|
||||||
} else if (spatialNum == 2 && output_format == cldnn::format::bfyx) {
|
|
||||||
output_format = cldnn::format::bfzyx;
|
|
||||||
} else if (spatialNum == 3 && output_format == cldnn::format::bfyx) {
|
|
||||||
output_format = cldnn::format::bfwzyx;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output_type = input_layout.data_type;
|
auto output_type = input_layout.data_type;
|
||||||
if (node.has_fused_primitives()) {
|
if (node.has_fused_primitives()) {
|
||||||
|
@ -5312,6 +5312,7 @@ struct gather_test_params {
|
|||||||
tensor dictionary_shape;
|
tensor dictionary_shape;
|
||||||
tensor indices_shape;
|
tensor indices_shape;
|
||||||
tensor out_shape;
|
tensor out_shape;
|
||||||
|
format out_format;
|
||||||
cldnn::gather::gather_axis axis;
|
cldnn::gather::gather_axis axis;
|
||||||
data_types data_type;
|
data_types data_type;
|
||||||
format input_format;
|
format input_format;
|
||||||
@ -5321,29 +5322,29 @@ struct gather_test_params {
|
|||||||
size_t expected_not_fused_primitives;
|
size_t expected_not_fused_primitives;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define CASE_GATHER_FP32_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
#define CASE_GATHER_FP32_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
#define CASE_GATHER_FP32_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
#define CASE_GATHER_FP32_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
#define CASE_GATHER_FP32_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
#define CASE_GATHER_FP32_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, format::bfyx, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
#define CASE_GATHER_FP32_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
#define CASE_GATHER_FP32_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
#define CASE_GATHER_FP32_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
#define CASE_GATHER_FP32_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfyx, data_types::f32, format::bfyx
|
||||||
|
|
||||||
#define CASE_GATHER_FP16_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
#define CASE_GATHER_FP16_1 {2, 3, 1, 4}, {4, 1, 1, 1}, {4, 3, 1, 4}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
#define CASE_GATHER_FP16_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
#define CASE_GATHER_FP16_2 {3, 2, 1, 2}, {2, 3, 1, 1}, {2, 3, 2, 2}, format::bfyx, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
#define CASE_GATHER_FP16_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
#define CASE_GATHER_FP16_3 {3, 1, 1, 2}, {2, 1, 1, 1}, {3, 2, 1, 2}, format::bfyx, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
#define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
#define CASE_GATHER_FP16_4 {5, 3, 2, 2}, {3, 1, 1, 1}, {5, 2, 2, 3}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
#define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
#define CASE_GATHER_FP16_5 {2, 3, 1, 2}, {1, 3, 1, 1}, {2, 3, 3, 1}, format::bfyx, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfyx, data_types::f16, format::bfyx
|
||||||
|
|
||||||
#define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
#define CASE_GATHER_5D_FP32_1 {2, 3, 1, 4, 1}, {4, 1, 1, 1}, {4, 3, 1, 4, 1}, format::bfzyx, cldnn::gather::gather_axis::along_b, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
#define CASE_GATHER_5D_FP32_2 {2, 3, 2, 2, 2}, {2, 1, 1, 1}, {2, 2, 2, 2, 2}, format::bfzyx, cldnn::gather::gather_axis::along_f, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
#define CASE_GATHER_5D_FP32_3 {5, 3, 2, 2, 2}, {3, 1, 1, 1}, {5, 3, 2, 3, 2}, format::bfzyx, cldnn::gather::gather_axis::along_y, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
#define CASE_GATHER_5D_FP32_4 {2, 3, 1, 4, 4}, {2, 1, 1, 1}, {2, 3, 1, 4, 2}, format::bfzyx, cldnn::gather::gather_axis::along_z, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
#define CASE_GATHER_5D_FP32_5 {3, 1, 5, 2, 1}, {2, 1, 1, 1}, {3, 1, 2, 2, 1}, format::bfzyx, cldnn::gather::gather_axis::along_x, data_types::f32, format::bfzyx, data_types::f32, format::bfzyx
|
||||||
|
|
||||||
#define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
#define CASE_GATHER_5D_FP16_1 {3, 2, 1, 2, 1}, {2, 1, 1, 1}, {2, 2, 2, 2, 1}, format::bfzyx, cldnn::gather::gather_axis::along_b, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
#define CASE_GATHER_5D_FP16_2 {1, 3, 1, 2, 1}, {2, 1, 1, 1}, {1, 2, 1, 2, 1}, format::bfzyx, cldnn::gather::gather_axis::along_f, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
#define CASE_GATHER_5D_FP16_3 {2, 3, 1, 3, 3}, {1, 2, 1, 1}, {2, 3, 1, 2, 3}, format::bfzyx, cldnn::gather::gather_axis::along_y, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
#define CASE_GATHER_5D_FP16_4 {3, 2, 2, 2, 2}, {2, 1, 1, 1}, {3, 2, 2, 2, 2}, format::bfzyx, cldnn::gather::gather_axis::along_z, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||||
#define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
#define CASE_GATHER_5D_FP16_5 {1, 1, 2, 1, 1}, {3, 1, 1, 1}, {1, 1, 3, 1, 1}, format::bfzyx, cldnn::gather::gather_axis::along_x, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
|
||||||
|
|
||||||
class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> {
|
class GatherPrimitiveFusingTest : public ::BaseFusingTest<gather_test_params> {
|
||||||
public:
|
public:
|
||||||
@ -5398,7 +5399,7 @@ TEST_P(gather_quantize, basic) {
|
|||||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||||
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||||
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||||
gather("gather_prim", "input", "gather_indices", p.axis, p.out_shape),
|
gather("gather_prim", "input", "gather_indices", p.axis, p.out_format, p.out_shape),
|
||||||
quantize("quantize", "gather_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
|
quantize("quantize", "gather_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
|
||||||
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
||||||
);
|
);
|
||||||
@ -5440,7 +5441,7 @@ TEST_P(gather_scale_activation, basic) {
|
|||||||
create_topologies(input_layout("input", get_input_layout(p)),
|
create_topologies(input_layout("input", get_input_layout(p)),
|
||||||
data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)))),
|
data("gather_indices", get_mem(get_indices_layout(p), 0, static_cast<int>(get_axis_dim(p)))),
|
||||||
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
|
data("scale_data", get_mem(get_per_channel_layout(p), -10, 10)),
|
||||||
gather("gather_prim", "input", "gather_indices", p.axis, p.out_shape),
|
gather("gather_prim", "input", "gather_indices", p.axis, p.out_format, p.out_shape),
|
||||||
activation("activation", "gather_prim", activation_func::abs),
|
activation("activation", "gather_prim", activation_func::abs),
|
||||||
scale("scale", "activation", "scale_data"),
|
scale("scale", "activation", "scale_data"),
|
||||||
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
|
reorder("reorder_bfyx", "scale", p.default_format, data_types::f32)
|
||||||
|
@ -63,7 +63,7 @@ TEST(gather_gpu_fp16, d14_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(1, 4, 1, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(1, 4, 1, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -125,7 +125,7 @@ TEST(gather_gpu_fp16, d222_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -186,7 +186,7 @@ TEST(gather_gpu_fp16, d22_axisY) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -247,7 +247,7 @@ TEST(gather_gpu_fp16, d22_axisF) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -305,7 +305,7 @@ TEST(gather_gpu_fp32, d14_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(1, 4, 1, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(1, 4, 1, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -366,7 +366,7 @@ TEST(gather_gpu_fp32, d222_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -427,7 +427,7 @@ TEST(gather_gpu_fp32, d22_axisY) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -488,7 +488,7 @@ TEST(gather_gpu_fp32, d22_axisF) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -549,7 +549,7 @@ TEST(gather_gpu_int32, d22_axisF) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -607,7 +607,7 @@ TEST(gather_gpu_int32, d14_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(1, 4, 1, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(1, 4, 1, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -668,7 +668,7 @@ TEST(gather_gpu_int32, d222_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -729,7 +729,7 @@ TEST(gather_gpu_int32, d22_axisY) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 1, 2, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -793,7 +793,7 @@ TEST(gather_gpu_fp32, d41_axisB) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(4, 1, 3, 2))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(4, 1, 3, 2))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -856,7 +856,7 @@ TEST(gather_gpu_fp32, d41_axisF) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 4, 2, 1))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 4, 2, 1))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -915,7 +915,7 @@ TEST(gather_gpu_fp32, d2_axisX) {
|
|||||||
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
topology.add(input_layout("InputText", input2.get_layout()));
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
topology.add(
|
topology.add(
|
||||||
gather("gather", "InputDictionary", "InputText", axis, tensor(2, 2, 2, 1))
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(2, 2, 2, 1))
|
||||||
);
|
);
|
||||||
|
|
||||||
network network(engine, topology);
|
network network(engine, topology);
|
||||||
@ -938,3 +938,52 @@ TEST(gather_gpu_fp32, d2_axisX) {
|
|||||||
EXPECT_EQ(expected_results[i], output_ptr[i]) << " at i=" << i;
|
EXPECT_EQ(expected_results[i], output_ptr[i]) << " at i=" << i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(gather_gpu_fp32, 322_axisF) {
|
||||||
|
// Dictionary : 3x3x1x1
|
||||||
|
// Indexes : 2x2x1x1
|
||||||
|
// Axis : 1
|
||||||
|
// Output : 3x2x2x1
|
||||||
|
// Input values in i32
|
||||||
|
|
||||||
|
engine engine;
|
||||||
|
|
||||||
|
auto input1 = memory::allocate(engine, { data_types::i32, format::bfyx, { 3, 3, 1, 1 } }); // data
|
||||||
|
auto input2 = memory::allocate(engine, { data_types::i32, format::bfyx, { 2, 2, 1, 1 } }); // Indexes
|
||||||
|
auto axis = cldnn::gather::gather_axis::along_f;
|
||||||
|
|
||||||
|
set_values(input1, {
|
||||||
|
0, 1, 2, 10, 11, 12, 20, 21, 22
|
||||||
|
});
|
||||||
|
|
||||||
|
set_values(input2, {
|
||||||
|
1, 0,
|
||||||
|
2, 1
|
||||||
|
});
|
||||||
|
|
||||||
|
topology topology;
|
||||||
|
topology.add(input_layout("InputDictionary", input1.get_layout()));
|
||||||
|
topology.add(input_layout("InputText", input2.get_layout()));
|
||||||
|
topology.add(
|
||||||
|
gather("gather", "InputDictionary", "InputText", axis, format::bfyx, tensor(3, 2, 1, 2))
|
||||||
|
);
|
||||||
|
|
||||||
|
network network(engine, topology);
|
||||||
|
|
||||||
|
network.set_input_data("InputDictionary", input1);
|
||||||
|
network.set_input_data("InputText", input2);
|
||||||
|
|
||||||
|
auto outputs = network.execute();
|
||||||
|
|
||||||
|
auto output = outputs.at("gather").get_memory();
|
||||||
|
auto output_ptr = output.pointer<int>();
|
||||||
|
|
||||||
|
std::vector<int> expected_results = {
|
||||||
|
1, 0, 2, 1, 11, 10, 12, 11, 21, 20, 22, 21
|
||||||
|
};
|
||||||
|
|
||||||
|
ASSERT_EQ(expected_results.size(), output_ptr.size());
|
||||||
|
for (size_t i = 0; i < expected_results.size(); ++i) {
|
||||||
|
EXPECT_EQ(expected_results[i], output_ptr[i]) << i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user