[GPU] Fix GatherND shape agnostic ref kernel (#19706)

Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
Andrew Kwangwoong Park 2023-09-11 17:20:10 +09:00 committed by GitHub
parent 530da61a4e
commit 161ba14796
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 20 deletions

View File

@ -43,21 +43,6 @@ KERNEL(gather_nd_ref)(
#endif
)
{
#if IS_DYNAMIC
uint wi_slice = 1;
#if INPUT0_DIMS == 4
uint input_dim[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#elif INPUT0_DIMS == 5
uint input_dim[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#else
uint input_dim[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#endif
for (uint i = BATCH_DIMS + INPUT1_SIZE_X; i < INPUT0_DIMS; i++)
wi_slice *= input_dim[i];
#define WI_SLICE_SIZE wi_slice
#else
#define WI_SLICE_SIZE WI_SLICE_SIZE_STATIC
#endif
const uint dim0 = get_global_id(0);
const uint dim1 = get_global_id(1);
const uint dim2 = get_global_id(2);
@ -98,19 +83,36 @@ KERNEL(gather_nd_ref)(
const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
#endif
#if IS_DYNAMIC
uint wi_slice = 1;
#if INPUT0_DIMS == 4
uint input_dims[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#elif INPUT0_DIMS == 5
uint input_dims[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#else
uint input_dims[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
#endif
const uint indices_last_dim = idx_dim[INDICES_RANK - 1];
for (uint i = BATCH_DIMS + indices_last_dim; i < INPUT0_DIMS; i++)
wi_slice *= input_dims[i];
#else
const uint wi_slice = WI_SLICE_SIZE;
const uint indices_last_dim = INDICES_LAST_DIM;
#endif
const int idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);
// Calculate data index
uint indices_val[INDICES_MAX_DIM + BATCH_DIMS];
for (int i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) {
for (uint i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) {
indices_val[i] = 0;
}
for (int i = 0; i < BATCH_DIMS; i++) {
for (uint i = 0; i < BATCH_DIMS; i++) {
indices_val[i] = idx_arr[i];
}
for (int i = 0; i < INDICES_LAST_DIM; i++) {
for (uint i = 0; i < indices_last_dim; i++) {
indices_val[i + BATCH_DIMS] = indices[idx+i];
}
@ -204,7 +206,7 @@ KERNEL(gather_nd_ref)(
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
#endif
for (int i = 0; i < WI_SLICE_SIZE; i++) {
for (uint i = 0; i < wi_slice; i++) {
uint dst_idx = output_idx + i;
INPUT0_TYPE val = data[data_idx + i];

View File

@ -119,7 +119,7 @@ JitConstants GatherNDKernelRef::GetJitConstants(const gather_nd_params& params)
jit.AddConstant(MakeJitConstant("INDICES_RANK", params.indices_rank));
jit.AddConstant(MakeJitConstant("BATCH_DIMS", params.batch_dims));
jit.AddConstant(MakeJitConstant("BATCH_MERGED_OUTPUT", params.batch_merged_output));
jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE_STATIC", GetSliceSize(params)));
jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE", GetSliceSize(params)));
jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", GetIndicesLastDim(params)));
if (!params.fused_ops.empty()) {

View File

@ -131,6 +131,11 @@ const std::vector<GatherNDShapeParams> dynamicInputShapeConstTargetShape = {
ov::test::InputShape(ov::PartialShape({}), {{2, 1}}),
1
},
{
ov::test::InputShape(ov::PartialShape({-1, -1}), {{10, 14}}),
ov::test::InputShape(ov::PartialShape({}), {{3, 2}}),
0
},
{
ov::test::InputShape(ov::PartialShape({-1, -1, -1}), {{2, 3, 4}, {3, 4, 5}}),
ov::test::InputShape(ov::PartialShape({}), {{2, 1}}),