[GPU] Fix weight reorder bug (#15672)

This commit is contained in:
Dohyun Kim (Felix) 2023-02-23 14:48:46 +09:00 committed by GitHub
parent c749163f72
commit 1028c7b5d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 47 deletions

View File

@ -4,27 +4,28 @@
#include "common.cl" #include "common.cl"
#define GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(prefix, o, i, y, x, sub_group_size) \ #define GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(prefix, o, i, y, x, osv, isv) \
CAT(prefix, _OFFSET) + \ get_os_is_zyx_isv_osv_index( \
((o) % (sub_group_size)) + \ o, i, 0, y, x, \
(sub_group_size)*( \ CAT(prefix, _SIZE_X), \
(x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \ CAT(prefix, _SIZE_Y), \
(y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \ 1, \
((i) % (sub_group_size)) + \ CAT(prefix, _IFM_NUM), \
((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \ CAT(prefix, _OFM_NUM), \
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ osv, \
isv \
) )
#define GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \ #define GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(prefix, o, i, z, y, x, osv, isv) \
CAT(prefix, _OFFSET) + \ get_os_is_zyx_isv_osv_index( \
((o) % (sub_group_size)) + \ o, i, z, y, x, \
(sub_group_size)*( \ CAT(prefix, _SIZE_X), \
(x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \ CAT(prefix, _SIZE_Y), \
(y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \ CAT(prefix, _SIZE_Z), \
(z)*(sub_group_size)*CAT(prefix, _Z_PITCH) + \ CAT(prefix, _IFM_NUM), \
((i) % (sub_group_size)) + \ CAT(prefix, _OFM_NUM), \
((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \ osv, \
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ isv \
) )
#define GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \ #define GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \
@ -85,6 +86,32 @@
CAT(prefix, _OFFSET) \ CAT(prefix, _OFFSET) \
) )
// Returns the linear element offset of weight (o, i, z, y, x) in a blocked
// layout whose dimensions, from slowest- to fastest-varying, are:
//   o-block, i-block, z, y, x, i-within-block, o-within-block
// where an o-block holds `osv_size` output features and an i-block holds
// `isv_size` input features.
//
//   o, i        - output / input feature coordinates
//   z, y, x     - spatial coordinates
//   x_size, y_size, z_size - spatial extents
//   i_size      - number of input features; rounded up to whole i-blocks
//   o_size      - accepted for signature symmetry; not read by this function
//   osv_size, isv_size - block widths for o and i (used as divisors)
//
// No base offset is added; the result is relative to the first element.
inline uint get_os_is_zyx_isv_osv_index(uint o, uint i, uint z, uint y, uint x,
    uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
{
    // Split each feature coordinate into (block number, lane inside block).
    const uint o_block = o / osv_size;
    const uint o_lane = o % osv_size;
    const uint i_block = i / isv_size;
    const uint i_lane = i % isv_size;

    // Count of i-blocks, with a partially filled last block counted as whole.
    const uint i_block_count = (i_size + isv_size - 1) / isv_size;

    // Fold the coordinates from the outermost dimension to the innermost one:
    // each step scales by the extent of the next dimension and adds its index.
    uint linear = o_block;
    linear = linear * i_block_count + i_block;
    linear = linear * z_size + z;
    linear = linear * y_size + y;
    linear = linear * x_size + x;
    linear = linear * isv_size + i_lane;
    linear = linear * osv_size + o_lane;
    return linear;
}
inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x, inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x,
uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size) uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
{ {
@ -329,7 +356,7 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint
#define GET_FILTER_INDEX_5D_SAFE(prefix, g, o, i, z, y, x) GET_FILTER_GOIZYX_SAFE(prefix, g, o, i, z, y, x) #define GET_FILTER_INDEX_5D_SAFE(prefix, g, o, i, z, y, x) GET_FILTER_GOIZYX_SAFE(prefix, g, o, i, z, y, x)
#define GET_FILTER_OS_IYX_OSV8_INDEX(prefix, o, i, y, x, sub_group_size) \ #define GET_FILTER_OS_IYX_OSV_INDEX(prefix, o, i, y, x, sub_group_size) \
CAT(prefix, _OFFSET) + \ CAT(prefix, _OFFSET) + \
((o) % (sub_group_size)) + \ ((o) % (sub_group_size)) + \
(sub_group_size)*( \ (sub_group_size)*( \
@ -339,7 +366,7 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \ ((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
) )
#define GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(prefix, o, i, y, x, sub_group_size) \ #define GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(prefix, o, i, y, x, sub_group_size) \
CAT(prefix, _OFFSET) + \ CAT(prefix, _OFFSET) + \
((o) % (sub_group_size)) + \ ((o) % (sub_group_size)) + \
(sub_group_size)*( \ (sub_group_size)*( \
@ -1495,16 +1522,6 @@ inline uint get_os_i_yxs_osv_yxsv4_index(uint o, uint i, uint y, uint x, uint i_
CAT(prefix, _SIZE_Y), \ CAT(prefix, _SIZE_Y), \
4) 4)
// Linear offset into the weights buffer described by the `prefix` macros:
//   <prefix>_OFFSET
//     + (o % sub_group_size)
//     + sub_group_size * ( x * <prefix>_X_PITCH
//                        + y * <prefix>_Y_PITCH
//                        + i * <prefix>_IFM_PITCH
//                        + (o / sub_group_size) * <prefix>_OFM_PITCH )
// The <prefix>_* names are built with CAT(), which is defined outside this
// excerpt (presumably token pasting - confirm at its definition).
// The whole expansion is parenthesized so the macro remains a single operand
// when it appears inside a larger expression such as `2 * MACRO(...)`.
#define GET_FILTER_OS_IYX_OSV32__AI32_INDEX(prefix, o, i, y, x, sub_group_size) \
    (CAT(prefix, _OFFSET) +                                                     \
     ((o) % (sub_group_size)) +                                                 \
     (sub_group_size)*(                                                         \
         (x)*CAT(prefix, _X_PITCH) +                                            \
         (y)*CAT(prefix, _Y_PITCH) +                                            \
         (i)*CAT(prefix, _IFM_PITCH) +                                          \
         ((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH)                       \
     ))
#define GET_FILTER_G_OS_IYX_OSV16(prefix, g, o, i, y, x, sub_group_size) \ #define GET_FILTER_G_OS_IYX_OSV16(prefix, g, o, i, y, x, sub_group_size) \
CAT(prefix, _OFFSET) + \ CAT(prefix, _OFFSET) + \
(g * CAT(prefix, _GROUPS_PITCH)) + \ (g * CAT(prefix, _GROUPS_PITCH)) + \

View File

@ -25,19 +25,20 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_INDEX_5D(INPUT0, 0, o, i, z, y, x); return GET_FILTER_INDEX_5D(INPUT0, 0, o, i, z, y, x);
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16 || \ #elif defined INPUT0_LAYOUT_OS_IYX_OSV16 || \
defined INPUT0_LAYOUT_OS_I_OSV16 || \ defined INPUT0_LAYOUT_OS_I_OSV16 || \
defined INPUT0_LAYOUT_OS_I_OSV8__AI8 || \
defined INPUT0_LAYOUT_OS_I_OSV16__AI8 defined INPUT0_LAYOUT_OS_I_OSV16__AI8
return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 16);
#elif defined INPUT0_LAYOUT_OS_I_OSV8__AI8
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 8);
#elif defined INPUT0_LAYOUT_IYX_OSV32 #elif defined INPUT0_LAYOUT_IYX_OSV32
return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, 32); return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32);
#elif defined INPUT0_LAYOUT_OS_IYX_OSV32__AI32 #elif defined INPUT0_LAYOUT_OS_IYX_OSV32__AI32
return GET_FILTER_OS_IYX_OSV32__AI32_INDEX(OUTPUT, o, i, y, x, 32); return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32);
#elif defined INPUT0_LAYOUT_O_IS_YX_ISV16 #elif defined INPUT0_LAYOUT_O_IS_YX_ISV16
return GET_FILTER_O_IS_YX_ISV16_INDEX(INPUT0, o, i, y, x, 16); return GET_FILTER_O_IS_YX_ISV16_INDEX(INPUT0, o, i, y, x, 16);
#elif defined INPUT0_LAYOUT_IYX_OSV64 #elif defined INPUT0_LAYOUT_IYX_OSV64
return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, 64); return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 64);
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16_ROTATE_180 #elif defined INPUT0_LAYOUT_OS_IYX_OSV16_ROTATE_180
return GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_I_YXS_OS_YXSV2_OSV16 #elif defined INPUT0_LAYOUT_I_YXS_OS_YXSV2_OSV16
return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32 #elif defined INPUT0_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32
@ -61,11 +62,11 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
#elif defined INPUT0_LAYOUT_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4 #elif defined INPUT0_LAYOUT_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4
return GET_FILTER_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4(INPUT0, o, i, y, x); return GET_FILTER_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_YX_ISV16_OSV16 #elif defined INPUT0_LAYOUT_OS_IS_YX_ISV16_OSV16
return GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(INPUT0, o, i, y, x, 16, 16);
#elif defined INPUT0_LAYOUT_OIYX_O16 #elif defined INPUT0_LAYOUT_OIYX_O16
return GET_FILTER_OIYX_O16(INPUT0, o, i, y, x); return GET_FILTER_OIYX_O16(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV16_OSV16 #elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV16_OSV16
return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(INPUT0, o, i, z, y, x, 16, 16);
#elif defined INPUT0_LAYOUT_IS_OS_ZYX_ISV16_OSV16 #elif defined INPUT0_LAYOUT_IS_OS_ZYX_ISV16_OSV16
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE); return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV16 #elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV16
@ -219,19 +220,20 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_INDEX_5D(OUTPUT, 0, o, i, z, y, x); return GET_FILTER_INDEX_5D(OUTPUT, 0, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16 || \ #elif defined OUTPUT_LAYOUT_OS_IYX_OSV16 || \
defined OUTPUT_LAYOUT_OS_I_OSV16 || \ defined OUTPUT_LAYOUT_OS_I_OSV16 || \
defined OUTPUT_LAYOUT_OS_I_OSV8__AI8 || \
defined OUTPUT_LAYOUT_OS_I_OSV16__AI8 defined OUTPUT_LAYOUT_OS_I_OSV16__AI8
return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 16);
#elif defined OUTPUT_LAYOUT_OS_I_OSV8__AI8
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 8);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV32 #elif defined OUTPUT_LAYOUT_OS_IYX_OSV32
return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, 32); return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV32__AI32 #elif defined OUTPUT_LAYOUT_OS_IYX_OSV32__AI32
return GET_FILTER_OS_IYX_OSV32__AI32_INDEX(OUTPUT, o, i, y, x, 32); return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV64 #elif defined OUTPUT_LAYOUT_OS_IYX_OSV64
return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, 64); return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 64);
#elif defined OUTPUT_LAYOUT_O_IS_YX_ISV16 #elif defined OUTPUT_LAYOUT_O_IS_YX_ISV16
return GET_FILTER_O_IS_YX_ISV16_INDEX(OUTPUT, o, i, y, x, 16); return GET_FILTER_O_IS_YX_ISV16_INDEX(OUTPUT, o, i, y, x, 16);
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16_ROTATE_180 #elif defined OUTPUT_LAYOUT_OS_IYX_OSV16_ROTATE_180
return GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_I_YXS_OS_YXSV2_OSV16 #elif defined OUTPUT_LAYOUT_I_YXS_OS_YXSV2_OSV16
return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32 #elif defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32
@ -313,11 +315,11 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4 #elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4
return GET_FILTER_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, z, y, x); return GET_FILTER_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV16_OSV16 #elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV16_OSV16
return GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(OUTPUT, o, i, y, x, 16, 16);
#elif defined OUTPUT_LAYOUT_OS_YXI_OSV16 #elif defined OUTPUT_LAYOUT_OS_YXI_OSV16
return GET_FILTER_OS_YXI_OSV16(OUTPUT, o, i, y, x); return GET_FILTER_OS_YXI_OSV16(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_ISV16_OSV16 #elif defined OUTPUT_LAYOUT_OS_IS_ZYX_ISV16_OSV16
return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE); return GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(OUTPUT, o, i, z, y, x, 16, 16);
#elif defined OUTPUT_LAYOUT_IS_OS_ZYX_ISV16_OSV16 #elif defined OUTPUT_LAYOUT_IS_OS_ZYX_ISV16_OSV16
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE); return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV16 #elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV16