[GPU] Update GatherND primitive (#8813)

* Cldnn output memory size at GatherND functional-test is aligned with TensorDesc of output blob
* Add param for rank of input data
* Update unittests to add rank of input data
* Update gpu fusing tests
This commit is contained in:
Kelvin Choi 2021-11-29 17:51:19 +09:00 committed by GitHub
parent d50f20b977
commit 2d996c1354
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 6 deletions

View File

@ -17,6 +17,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v5::G
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);
int32_t input_rank = static_cast<int32_t>(op->get_input_shape(0).size());
int32_t indices_rank = static_cast<int32_t>(op->get_input_shape(1).size());
auto batch_dims = op->get_batch_dims();
@ -24,6 +25,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v5::G
auto primitive = cldnn::gather_nd(layerName,
inputPrimitives[0],
inputPrimitives[1],
input_rank,
indices_rank,
batch_dims,
true,
@ -40,6 +42,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v8::G
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
std::string layerName = layer_type_name_ID(op);
int32_t input_rank = static_cast<int32_t>(op->get_input_shape(0).size());
int32_t indices_rank = static_cast<int32_t>(op->get_input_shape(1).size());
auto batch_dims = op->get_batch_dims();
@ -47,6 +50,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v8::G
auto primitive = cldnn::gather_nd(layerName,
inputPrimitives[0],
inputPrimitives[1],
input_rank,
indices_rank,
batch_dims,
false,

View File

@ -23,6 +23,7 @@ struct gather_nd : public primitive_base<gather_nd> {
/// @param id This primitive id.
/// @param data Input data primitive id.
/// @param indices Input indexes primitive id.
/// @param input_rank Rank of input data.
/// @param indices_rank Rank of indices.
/// @param batch_dims batch_dims as an attribute of GatherND. Optional.
/// @param batch_merged_output batched output shape is merged as a dimension for v5.
@ -32,16 +33,21 @@ struct gather_nd : public primitive_base<gather_nd> {
gather_nd(const primitive_id& id,
const primitive_id& data,
const primitive_id& indices,
const uint8_t input_rank,
const uint8_t indices_rank,
const uint8_t batch_dims = 0,
const bool batch_merged_output = true,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {data, indices}, ext_prim_id, output_padding),
input_rank(input_rank),
indices_rank(indices_rank),
batch_dims(batch_dims),
batch_merged_output(batch_merged_output) {}
/// @brief GatherND input_rank
uint8_t input_rank;
/// @brief GatherND indices_rank
uint8_t indices_rank;

View File

@ -24,8 +24,7 @@ layout gather_nd_inst::calc_output_layout(gather_nd_node const& node) {
auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format);
auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format);
const size_t input_dims = input_layout.size();
const auto input_rank = static_cast<size_t>(op->input_rank);
const auto indices_rank = op->indices_rank;
const auto batch_dims = op->batch_dims;
@ -37,7 +36,7 @@ layout gather_nd_inst::calc_output_layout(gather_nd_node const& node) {
}
const size_t indices_last_dim = indices_layout[indices_rank - 1];
for (size_t x = static_cast<size_t>(batch_dims + indices_last_dim); x < input_dims; x++) {
for (size_t x = static_cast<size_t>(batch_dims + indices_last_dim); x < input_rank; x++) {
output_sizes.push_back(input_layout[x]);
}

View File

@ -8749,13 +8749,23 @@ public:
class gather_nd_quantize : public GatherNDPrimitiveFusingTest {};
TEST_P(gather_nd_quantize, basic) {
auto p = GetParam();
auto input_rank = 0;
if (p.input_format == format::bfyx) {
input_rank = 4;
} else if (p.input_format == format::bfzyx) {
input_rank = 5;
} else if (p.input_format == format::bfwzyx) {
input_rank = 6;
}
create_topologies(input_layout("input", get_input_layout(p)),
data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
data("out_lo", get_mem(get_single_element_layout(p), -127)),
data("out_hi", get_mem(get_single_element_layout(p), 127)),
gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
gather_nd("gather_nd_prim", "input", "gather_nd_indices", input_rank, p.indices_rank, p.batch_dims),
quantize("quantize", "gather_nd_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
);
@ -8802,11 +8812,20 @@ class gather_nd_activation_scale_eltwise : public GatherNDPrimitiveFusingTest {}
TEST_P(gather_nd_activation_scale_eltwise, basic) {
auto p = GetParam();
auto input_rank = 0;
if (p.input_format == format::bfyx) {
input_rank = 4;
} else if (p.input_format == format::bfzyx) {
input_rank = 5;
} else if (p.input_format == format::bfwzyx) {
input_rank = 6;
}
create_topologies(input_layout("input", get_input_layout(p)),
data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)),
data("eltwise_data", get_mem(get_output_layout(p))),
gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
gather_nd("gather_nd_prim", "input", "gather_nd_indices", input_rank, p.indices_rank, p.batch_dims),
activation("activation", "gather_nd_prim", activation_func::abs),
scale("scale", "activation", "scale_data"),
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),

View File

@ -21,7 +21,19 @@ inline void DoTestBase(engine& engine,
const tensor ts,
const bool batch_merged_output) {
topology topology;
auto gather_nd_inst = gather_nd("gather_nd", "InputData", "InputIndices", indices_rank, batch_dims, batch_merged_output);
int input_rank = 0;
if (input0->get_layout().format == format::bfyx) {
input_rank = 4;
} else if (input0->get_layout().format == format::bfzyx) {
input_rank = 5;
} else if (input0->get_layout().format == format::bfwzyx) {
input_rank = 6;
} else {
FAIL();
}
auto gather_nd_inst = gather_nd("gather_nd", "InputData", "InputIndices", input_rank, indices_rank, batch_dims, batch_merged_output);
topology.add(input_layout("InputData", input0->get_layout()));
topology.add(input_layout("InputIndices", input1->get_layout()));
topology.add(gather_nd_inst);