[GPU] Update GatherND primitive (#8813)
* Cldnn output memory size at GatherND functional-test is aligned with TensorDesc of output blob * Add param for rank of input data * Update unittests to add rank of input data * Update gpu fusing tests
This commit is contained in:
parent
d50f20b977
commit
2d996c1354
@ -17,6 +17,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v5::G
|
||||
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
|
||||
std::string layerName = layer_type_name_ID(op);
|
||||
|
||||
int32_t input_rank = static_cast<int32_t>(op->get_input_shape(0).size());
|
||||
int32_t indices_rank = static_cast<int32_t>(op->get_input_shape(1).size());
|
||||
|
||||
auto batch_dims = op->get_batch_dims();
|
||||
@ -24,6 +25,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v5::G
|
||||
auto primitive = cldnn::gather_nd(layerName,
|
||||
inputPrimitives[0],
|
||||
inputPrimitives[1],
|
||||
input_rank,
|
||||
indices_rank,
|
||||
batch_dims,
|
||||
true,
|
||||
@ -40,6 +42,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v8::G
|
||||
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
|
||||
std::string layerName = layer_type_name_ID(op);
|
||||
|
||||
int32_t input_rank = static_cast<int32_t>(op->get_input_shape(0).size());
|
||||
int32_t indices_rank = static_cast<int32_t>(op->get_input_shape(1).size());
|
||||
|
||||
auto batch_dims = op->get_batch_dims();
|
||||
@ -47,6 +50,7 @@ static void CreateGatherNDOp(Program& p, const std::shared_ptr<ngraph::op::v8::G
|
||||
auto primitive = cldnn::gather_nd(layerName,
|
||||
inputPrimitives[0],
|
||||
inputPrimitives[1],
|
||||
input_rank,
|
||||
indices_rank,
|
||||
batch_dims,
|
||||
false,
|
||||
|
@ -23,6 +23,7 @@ struct gather_nd : public primitive_base<gather_nd> {
|
||||
/// @param id This primitive id.
|
||||
/// @param data Input data primitive id.
|
||||
/// @param indices Input indexes primitive id.
|
||||
/// @param input_rank Rank of input data.
|
||||
/// @param indices_rank Rank of indices.
|
||||
/// @param batch_dims batch_dims as an attribute of GatherND. Optional.
|
||||
/// @param batch_merged_output batched output shape is merged as a dimention for v5.
|
||||
@ -32,16 +33,21 @@ struct gather_nd : public primitive_base<gather_nd> {
|
||||
gather_nd(const primitive_id& id,
|
||||
const primitive_id& data,
|
||||
const primitive_id& indices,
|
||||
const uint8_t input_rank,
|
||||
const uint8_t indices_rank,
|
||||
const uint8_t batch_dims = 0,
|
||||
const bool batch_merged_output = true,
|
||||
const primitive_id& ext_prim_id = "",
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data, indices}, ext_prim_id, output_padding),
|
||||
input_rank(input_rank),
|
||||
indices_rank(indices_rank),
|
||||
batch_dims(batch_dims),
|
||||
batch_merged_output(batch_merged_output) {}
|
||||
|
||||
/// @brief GatherND input_rank
|
||||
uint8_t input_rank;
|
||||
|
||||
/// @brief GatherND indices_rank
|
||||
uint8_t indices_rank;
|
||||
|
||||
|
@ -24,8 +24,7 @@ layout gather_nd_inst::calc_output_layout(gather_nd_node const& node) {
|
||||
auto input_layout = input_layout_origin.size.sizes(input_layout_origin.format);
|
||||
auto indices_layout = indices_layout_origin.size.sizes(indices_layout_origin.format);
|
||||
|
||||
const size_t input_dims = input_layout.size();
|
||||
|
||||
const auto input_rank = static_cast<size_t>(op->input_rank);
|
||||
const auto indices_rank = op->indices_rank;
|
||||
const auto batch_dims = op->batch_dims;
|
||||
|
||||
@ -37,7 +36,7 @@ layout gather_nd_inst::calc_output_layout(gather_nd_node const& node) {
|
||||
}
|
||||
|
||||
const size_t indices_last_dim = indices_layout[indices_rank - 1];
|
||||
for (size_t x = static_cast<size_t>(batch_dims + indices_last_dim); x < input_dims; x++) {
|
||||
for (size_t x = static_cast<size_t>(batch_dims + indices_last_dim); x < input_rank; x++) {
|
||||
output_sizes.push_back(input_layout[x]);
|
||||
}
|
||||
|
||||
|
@ -8749,13 +8749,23 @@ public:
|
||||
class gather_nd_quantize : public GatherNDPrimitiveFusingTest {};
|
||||
TEST_P(gather_nd_quantize, basic) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto input_rank = 0;
|
||||
if (p.input_format == format::bfyx) {
|
||||
input_rank = 4;
|
||||
} else if (p.input_format == format::bfzyx) {
|
||||
input_rank = 5;
|
||||
} else if (p.input_format == format::bfwzyx) {
|
||||
input_rank = 6;
|
||||
}
|
||||
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
|
||||
data("in_lo", get_mem(get_per_channel_layout(p), min_random, 0)),
|
||||
data("in_hi", get_mem(get_per_channel_layout(p), 1, max_random)),
|
||||
data("out_lo", get_mem(get_single_element_layout(p), -127)),
|
||||
data("out_hi", get_mem(get_single_element_layout(p), 127)),
|
||||
gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
|
||||
gather_nd("gather_nd_prim", "input", "gather_nd_indices", input_rank, p.indices_rank, p.batch_dims),
|
||||
quantize("quantize", "gather_nd_prim", "in_lo", "in_hi", "out_lo", "out_hi", 255, data_types::i8),
|
||||
reorder("reorder_bfyx", "quantize", p.default_format, data_types::f32)
|
||||
);
|
||||
@ -8802,11 +8812,20 @@ class gather_nd_activation_scale_eltwise : public GatherNDPrimitiveFusingTest {}
|
||||
TEST_P(gather_nd_activation_scale_eltwise, basic) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto input_rank = 0;
|
||||
if (p.input_format == format::bfyx) {
|
||||
input_rank = 4;
|
||||
} else if (p.input_format == format::bfzyx) {
|
||||
input_rank = 5;
|
||||
} else if (p.input_format == format::bfwzyx) {
|
||||
input_rank = 6;
|
||||
}
|
||||
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("gather_nd_indices", get_mem(get_indices_layout(p), 0, p.max_number_in_indices - 1)),
|
||||
data("scale_data", get_mem(get_per_channel_layout(p), 1.0f / 255)),
|
||||
data("eltwise_data", get_mem(get_output_layout(p))),
|
||||
gather_nd("gather_nd_prim", "input", "gather_nd_indices", p.indices_rank, p.batch_dims),
|
||||
gather_nd("gather_nd_prim", "input", "gather_nd_indices", input_rank, p.indices_rank, p.batch_dims),
|
||||
activation("activation", "gather_nd_prim", activation_func::abs),
|
||||
scale("scale", "activation", "scale_data"),
|
||||
eltwise("eltwise", { "scale", "eltwise_data" }, eltwise_mode::sum, p.data_type),
|
||||
|
@ -21,7 +21,19 @@ inline void DoTestBase(engine& engine,
|
||||
const tensor ts,
|
||||
const bool batch_merged_output) {
|
||||
topology topology;
|
||||
auto gather_nd_inst = gather_nd("gather_nd", "InputData", "InputIndices", indices_rank, batch_dims, batch_merged_output);
|
||||
|
||||
int input_rank = 0;
|
||||
if (input0->get_layout().format == format::bfyx) {
|
||||
input_rank = 4;
|
||||
} else if (input0->get_layout().format == format::bfzyx) {
|
||||
input_rank = 5;
|
||||
} else if (input0->get_layout().format == format::bfwzyx) {
|
||||
input_rank = 6;
|
||||
} else {
|
||||
FAIL();
|
||||
}
|
||||
|
||||
auto gather_nd_inst = gather_nd("gather_nd", "InputData", "InputIndices", input_rank, indices_rank, batch_dims, batch_merged_output);
|
||||
topology.add(input_layout("InputData", input0->get_layout()));
|
||||
topology.add(input_layout("InputIndices", input1->get_layout()));
|
||||
topology.add(gather_nd_inst);
|
||||
|
Loading…
Reference in New Issue
Block a user