[GPU] Fix GatherND shape-agnostic ref kernel (#19706)
Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
parent
530da61a4e
commit
161ba14796
@ -43,21 +43,6 @@ KERNEL(gather_nd_ref)(
|
||||
#endif
|
||||
)
|
||||
{
|
||||
#if IS_DYNAMIC
|
||||
uint wi_slice = 1;
|
||||
#if INPUT0_DIMS == 4
|
||||
uint input_dim[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
|
||||
#elif INPUT0_DIMS == 5
|
||||
uint input_dim[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
|
||||
#else
|
||||
uint input_dim[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
|
||||
#endif
|
||||
for (uint i = BATCH_DIMS + INPUT1_SIZE_X; i < INPUT0_DIMS; i++)
|
||||
wi_slice *= input_dim[i];
|
||||
#define WI_SLICE_SIZE wi_slice
|
||||
#else
|
||||
#define WI_SLICE_SIZE WI_SLICE_SIZE_STATIC
|
||||
#endif
|
||||
const uint dim0 = get_global_id(0);
|
||||
const uint dim1 = get_global_id(1);
|
||||
const uint dim2 = get_global_id(2);
|
||||
@ -98,19 +83,36 @@ KERNEL(gather_nd_ref)(
|
||||
const uint idx_dim[INPUT1_DIMS] = {INPUT1_BATCH_NUM, INPUT1_FEATURE_NUM, INPUT1_SIZE_W, INPUT1_SIZE_Z, INPUT1_SIZE_Y, INPUT1_SIZE_X};
|
||||
#endif
|
||||
|
||||
#if IS_DYNAMIC
|
||||
uint wi_slice = 1;
|
||||
#if INPUT0_DIMS == 4
|
||||
uint input_dims[4] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Y, INPUT0_SIZE_X};
|
||||
#elif INPUT0_DIMS == 5
|
||||
uint input_dims[5] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
|
||||
#else
|
||||
uint input_dims[6] = {INPUT0_BATCH_NUM, INPUT0_FEATURE_NUM, INPUT0_SIZE_W, INPUT0_SIZE_Z, INPUT0_SIZE_Y, INPUT0_SIZE_X};
|
||||
#endif
|
||||
const uint indices_last_dim = idx_dim[INDICES_RANK - 1];
|
||||
for (uint i = BATCH_DIMS + indices_last_dim; i < INPUT0_DIMS; i++)
|
||||
wi_slice *= input_dims[i];
|
||||
#else
|
||||
const uint wi_slice = WI_SLICE_SIZE;
|
||||
const uint indices_last_dim = INDICES_LAST_DIM;
|
||||
#endif
|
||||
|
||||
const int idx = GET_UPDATES_INDEX(INPUT1, IDX_ORDER);
|
||||
|
||||
// Calculate data index
|
||||
uint indices_val[INDICES_MAX_DIM + BATCH_DIMS];
|
||||
for (int i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) {
|
||||
for (uint i = 0; i < INDICES_MAX_DIM + BATCH_DIMS; i++) {
|
||||
indices_val[i] = 0;
|
||||
}
|
||||
|
||||
for (int i = 0; i < BATCH_DIMS; i++) {
|
||||
for (uint i = 0; i < BATCH_DIMS; i++) {
|
||||
indices_val[i] = idx_arr[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < INDICES_LAST_DIM; i++) {
|
||||
for (uint i = 0; i < indices_last_dim; i++) {
|
||||
indices_val[i + BATCH_DIMS] = indices[idx+i];
|
||||
}
|
||||
|
||||
@ -204,7 +206,7 @@ KERNEL(gather_nd_ref)(
|
||||
const uint b_pitch = f_pitch * OUTPUT_FEATURE_NUM;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < WI_SLICE_SIZE; i++) {
|
||||
for (uint i = 0; i < wi_slice; i++) {
|
||||
uint dst_idx = output_idx + i;
|
||||
INPUT0_TYPE val = data[data_idx + i];
|
||||
|
||||
|
@ -119,7 +119,7 @@ JitConstants GatherNDKernelRef::GetJitConstants(const gather_nd_params& params)
|
||||
jit.AddConstant(MakeJitConstant("INDICES_RANK", params.indices_rank));
|
||||
jit.AddConstant(MakeJitConstant("BATCH_DIMS", params.batch_dims));
|
||||
jit.AddConstant(MakeJitConstant("BATCH_MERGED_OUTPUT", params.batch_merged_output));
|
||||
jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE_STATIC", GetSliceSize(params)));
|
||||
jit.AddConstant(MakeJitConstant("WI_SLICE_SIZE", GetSliceSize(params)));
|
||||
jit.AddConstant(MakeJitConstant("INDICES_LAST_DIM", GetIndicesLastDim(params)));
|
||||
|
||||
if (!params.fused_ops.empty()) {
|
||||
|
@ -131,6 +131,11 @@ const std::vector<GatherNDShapeParams> dynamicInputShapeConstTargetShape = {
|
||||
ov::test::InputShape(ov::PartialShape({}), {{2, 1}}),
|
||||
1
|
||||
},
|
||||
{
|
||||
ov::test::InputShape(ov::PartialShape({-1, -1}), {{10, 14}}),
|
||||
ov::test::InputShape(ov::PartialShape({}), {{3, 2}}),
|
||||
0
|
||||
},
|
||||
{
|
||||
ov::test::InputShape(ov::PartialShape({-1, -1, -1}), {{2, 3, 4}, {3, 4, 5}}),
|
||||
ov::test::InputShape(ov::PartialShape({}), {{2, 1}}),
|
||||
|
Loading…
Reference in New Issue
Block a user