From 1028c7b5d5479cb58216d88b6dd3dec1bb45ea33 Mon Sep 17 00:00:00 2001 From: "Dohyun Kim (Felix)" Date: Thu, 23 Feb 2023 14:48:46 +0900 Subject: [PATCH] [GPU] Fix weight reorder bug (#15672) --- .../include/batch_headers/fetch_weights.cl | 79 +++++++++++-------- .../cl_kernels/reorder_weights.cl | 34 ++++---- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/fetch_weights.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/fetch_weights.cl index 3dc30c7a88f..d7c86c5bed3 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/fetch_weights.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/include/batch_headers/fetch_weights.cl @@ -4,27 +4,28 @@ #include "common.cl" -#define GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(prefix, o, i, y, x, sub_group_size) \ - CAT(prefix, _OFFSET) + \ - ((o) % (sub_group_size)) + \ - (sub_group_size)*( \ - (x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \ - (y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \ - ((i) % (sub_group_size)) + \ - ((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \ - ((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ +#define GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(prefix, o, i, y, x, osv, isv) \ + get_os_is_zyx_isv_osv_index( \ + o, i, 0, y, x, \ + CAT(prefix, _SIZE_X), \ + CAT(prefix, _SIZE_Y), \ + 1, \ + CAT(prefix, _IFM_NUM), \ + CAT(prefix, _OFM_NUM), \ + osv, \ + isv \ ) -#define GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \ - CAT(prefix, _OFFSET) + \ - ((o) % (sub_group_size)) + \ - (sub_group_size)*( \ - (x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \ - (y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \ - (z)*(sub_group_size)*CAT(prefix, _Z_PITCH) + \ - ((i) % (sub_group_size)) + \ - ((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \ - ((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ +#define GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(prefix, o, i, z, y, x, osv, isv) \ + get_os_is_zyx_isv_osv_index( \ + o, i, z, y, x, \ + CAT(prefix, _SIZE_X), \ + CAT(prefix, _SIZE_Y), \ + CAT(prefix, _SIZE_Z), \ + CAT(prefix, _IFM_NUM), \ + CAT(prefix, _OFM_NUM), \ + osv, \ + isv \ ) #define GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \ @@ -85,6 +86,32 @@ CAT(prefix, _OFFSET) \ ) +inline uint get_os_is_zyx_isv_osv_index(uint o, uint i, uint z, uint y, uint x, + uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size) +{ + const uint isv = i % isv_size; + const uint osv = o % osv_size; + const uint is = i / isv_size; + const uint os = o / osv_size; + + const uint x_pitch = osv_size * isv_size; + const uint y_pitch = x_pitch * x_size; + const uint z_pitch = y_pitch * y_size; + const uint is_pitch = z_pitch * z_size; + const uint os_pitch = is_pitch * ((i_size + isv_size - 1) / isv_size); + + const uint output_offset = + osv + + isv * osv_size + + x * x_pitch + + y * y_pitch + + z * z_pitch + + is * is_pitch + + os * os_pitch; + + return output_offset; +} + inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x, uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size) { @@ -329,7 +356,7 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint #define GET_FILTER_INDEX_5D_SAFE(prefix, g, o, i, z, y, x) GET_FILTER_GOIZYX_SAFE(prefix, g, o, i, z, y, x) -#define GET_FILTER_OS_IYX_OSV8_INDEX(prefix, o, i, y, x, sub_group_size) \ +#define GET_FILTER_OS_IYX_OSV_INDEX(prefix, o, i, y, x, sub_group_size) \ CAT(prefix, _OFFSET) + \ ((o) % (sub_group_size)) + \ (sub_group_size)*( \ @@ -339,7 +366,7 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint ((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ ) -#define GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(prefix, o, i, y, x, sub_group_size) \ +#define GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(prefix, o, i, y, x, sub_group_size) \ CAT(prefix, _OFFSET) + \ ((o) % (sub_group_size)) + \ (sub_group_size)*( \ @@ -1495,16 +1522,6 @@ inline uint get_os_i_yxs_osv_yxsv4_index(uint o, uint i, uint y, uint x, uint i_ CAT(prefix, _SIZE_Y), \ 4) -#define GET_FILTER_OS_IYX_OSV32__AI32_INDEX(prefix, o, i, y, x, sub_group_size) \ - CAT(prefix, _OFFSET) + \ - ((o) % (sub_group_size)) + \ - (sub_group_size)*( \ - (x)*CAT(prefix, _X_PITCH) + \ - (y)*CAT(prefix, _Y_PITCH) + \ - (i)*CAT(prefix, _IFM_PITCH) + \ - ((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ - ) - #define GET_FILTER_G_OS_IYX_OSV16(prefix, g, o, i, y, x, sub_group_size) \ CAT(prefix, _OFFSET) + \ (g * CAT(prefix, _GROUPS_PITCH)) + \ diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_weights.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_weights.cl index 582c2f6c6c7..147ab43e837 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_weights.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_weights.cl @@ -25,19 +25,20 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x return GET_FILTER_INDEX_5D(INPUT0, 0, o, i, z, y, x); #elif defined INPUT0_LAYOUT_OS_IYX_OSV16 || \ defined INPUT0_LAYOUT_OS_I_OSV16 || \ - defined INPUT0_LAYOUT_OS_I_OSV8__AI8 || \ defined INPUT0_LAYOUT_OS_I_OSV16__AI8 - return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 16); +#elif defined INPUT0_LAYOUT_OS_I_OSV8__AI8 + return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 8); #elif defined INPUT0_LAYOUT_IYX_OSV32 - return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, 32); + return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32); #elif defined INPUT0_LAYOUT_OS_IYX_OSV32__AI32 - return GET_FILTER_OS_IYX_OSV32__AI32_INDEX(OUTPUT, o, i, y, x, 32); + return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32); #elif defined INPUT0_LAYOUT_O_IS_YX_ISV16 return GET_FILTER_O_IS_YX_ISV16_INDEX(INPUT0, o, i, y, x, 16); #elif defined INPUT0_LAYOUT_IYX_OSV64 - return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, 64); + return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 64); #elif defined INPUT0_LAYOUT_OS_IYX_OSV16_ROTATE_180 - return GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); #elif defined INPUT0_LAYOUT_I_YXS_OS_YXSV2_OSV16 return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); #elif defined INPUT0_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32 @@ -61,11 +62,11 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x #elif defined INPUT0_LAYOUT_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4 return GET_FILTER_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4(INPUT0, o, i, y, x); #elif defined INPUT0_LAYOUT_OS_IS_YX_ISV16_OSV16 - return GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(INPUT0, o, i, y, x, 16, 16); #elif defined INPUT0_LAYOUT_OIYX_O16 return GET_FILTER_OIYX_O16(INPUT0, o, i, y, x); #elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV16_OSV16 - return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(INPUT0, o, i, z, y, x, 16, 16); #elif defined INPUT0_LAYOUT_IS_OS_ZYX_ISV16_OSV16 return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE); #elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV16 @@ -219,19 +220,20 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint return GET_FILTER_INDEX_5D(OUTPUT, 0, o, i, z, y, x); #elif defined OUTPUT_LAYOUT_OS_IYX_OSV16 || \ defined OUTPUT_LAYOUT_OS_I_OSV16 || \ - defined OUTPUT_LAYOUT_OS_I_OSV8__AI8 || \ defined OUTPUT_LAYOUT_OS_I_OSV16__AI8 - return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 16); +#elif defined OUTPUT_LAYOUT_OS_I_OSV8__AI8 + return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 8); #elif defined OUTPUT_LAYOUT_OS_IYX_OSV32 - return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, 32); + return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32); #elif defined OUTPUT_LAYOUT_OS_IYX_OSV32__AI32 - return GET_FILTER_OS_IYX_OSV32__AI32_INDEX(OUTPUT, o, i, y, x, 32); + return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32); #elif defined OUTPUT_LAYOUT_OS_IYX_OSV64 - return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, 64); + return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 64); #elif defined OUTPUT_LAYOUT_O_IS_YX_ISV16 return GET_FILTER_O_IS_YX_ISV16_INDEX(OUTPUT, o, i, y, x, 16); #elif defined OUTPUT_LAYOUT_OS_IYX_OSV16_ROTATE_180 - return GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); #elif defined OUTPUT_LAYOUT_I_YXS_OS_YXSV2_OSV16 return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); #elif defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32 @@ -313,11 +315,11 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint #elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4 return GET_FILTER_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, z, y, x); #elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV16_OSV16 - return GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(OUTPUT, o, i, y, x, 16, 16); #elif defined OUTPUT_LAYOUT_OS_YXI_OSV16 return GET_FILTER_OS_YXI_OSV16(OUTPUT, o, i, y, x); #elif defined OUTPUT_LAYOUT_OS_IS_ZYX_ISV16_OSV16 - return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE); + return GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(OUTPUT, o, i, z, y, x, 16, 16); #elif defined OUTPUT_LAYOUT_IS_OS_ZYX_ISV16_OSV16 return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE); #elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV16