[GPU] Fix weight reorder bug (#15672)
This commit is contained in:
parent
c749163f72
commit
1028c7b5d5
@ -4,27 +4,28 @@
|
|||||||
|
|
||||||
#include "common.cl"
|
#include "common.cl"
|
||||||
|
|
||||||
#define GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(prefix, o, i, y, x, sub_group_size) \
|
#define GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(prefix, o, i, y, x, osv, isv) \
|
||||||
CAT(prefix, _OFFSET) + \
|
get_os_is_zyx_isv_osv_index( \
|
||||||
((o) % (sub_group_size)) + \
|
o, i, 0, y, x, \
|
||||||
(sub_group_size)*( \
|
CAT(prefix, _SIZE_X), \
|
||||||
(x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \
|
CAT(prefix, _SIZE_Y), \
|
||||||
(y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \
|
1, \
|
||||||
((i) % (sub_group_size)) + \
|
CAT(prefix, _IFM_NUM), \
|
||||||
((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \
|
CAT(prefix, _OFM_NUM), \
|
||||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
osv, \
|
||||||
|
isv \
|
||||||
)
|
)
|
||||||
|
|
||||||
#define GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \
|
#define GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(prefix, o, i, z, y, x, osv, isv) \
|
||||||
CAT(prefix, _OFFSET) + \
|
get_os_is_zyx_isv_osv_index( \
|
||||||
((o) % (sub_group_size)) + \
|
o, i, z, y, x, \
|
||||||
(sub_group_size)*( \
|
CAT(prefix, _SIZE_X), \
|
||||||
(x)*(sub_group_size)*CAT(prefix, _X_PITCH) + \
|
CAT(prefix, _SIZE_Y), \
|
||||||
(y)*(sub_group_size)*CAT(prefix, _Y_PITCH) + \
|
CAT(prefix, _SIZE_Z), \
|
||||||
(z)*(sub_group_size)*CAT(prefix, _Z_PITCH) + \
|
CAT(prefix, _IFM_NUM), \
|
||||||
((i) % (sub_group_size)) + \
|
CAT(prefix, _OFM_NUM), \
|
||||||
((i) / (sub_group_size))*(sub_group_size)*CAT(prefix, _IFM_PITCH) + \
|
osv, \
|
||||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
isv \
|
||||||
)
|
)
|
||||||
|
|
||||||
#define GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \
|
#define GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(prefix, o, i, z, y, x, sub_group_size) \
|
||||||
@ -85,6 +86,32 @@
|
|||||||
CAT(prefix, _OFFSET) \
|
CAT(prefix, _OFFSET) \
|
||||||
)
|
)
|
||||||
|
|
||||||
|
inline uint get_os_is_zyx_isv_osv_index(uint o, uint i, uint z, uint y, uint x,
|
||||||
|
uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
|
||||||
|
{
|
||||||
|
const uint isv = i % isv_size;
|
||||||
|
const uint osv = o % osv_size;
|
||||||
|
const uint is = i / isv_size;
|
||||||
|
const uint os = o / osv_size;
|
||||||
|
|
||||||
|
const uint x_pitch = osv_size * isv_size;
|
||||||
|
const uint y_pitch = x_pitch * x_size;
|
||||||
|
const uint z_pitch = y_pitch * y_size;
|
||||||
|
const uint is_pitch = z_pitch * z_size;
|
||||||
|
const uint os_pitch = is_pitch * ((i_size + isv_size - 1) / isv_size);
|
||||||
|
|
||||||
|
const uint output_offset =
|
||||||
|
osv +
|
||||||
|
isv * osv_size +
|
||||||
|
x * x_pitch +
|
||||||
|
y * y_pitch +
|
||||||
|
z * z_pitch +
|
||||||
|
is * is_pitch +
|
||||||
|
os * os_pitch;
|
||||||
|
|
||||||
|
return output_offset;
|
||||||
|
}
|
||||||
|
|
||||||
inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x,
|
inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x,
|
||||||
uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
|
uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
|
||||||
{
|
{
|
||||||
@ -329,7 +356,7 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint
|
|||||||
|
|
||||||
#define GET_FILTER_INDEX_5D_SAFE(prefix, g, o, i, z, y, x) GET_FILTER_GOIZYX_SAFE(prefix, g, o, i, z, y, x)
|
#define GET_FILTER_INDEX_5D_SAFE(prefix, g, o, i, z, y, x) GET_FILTER_GOIZYX_SAFE(prefix, g, o, i, z, y, x)
|
||||||
|
|
||||||
#define GET_FILTER_OS_IYX_OSV8_INDEX(prefix, o, i, y, x, sub_group_size) \
|
#define GET_FILTER_OS_IYX_OSV_INDEX(prefix, o, i, y, x, sub_group_size) \
|
||||||
CAT(prefix, _OFFSET) + \
|
CAT(prefix, _OFFSET) + \
|
||||||
((o) % (sub_group_size)) + \
|
((o) % (sub_group_size)) + \
|
||||||
(sub_group_size)*( \
|
(sub_group_size)*( \
|
||||||
@ -339,7 +366,7 @@ inline uint get_os_zyxi_osv16_index(uint o, uint i, uint z, uint y, uint x, uint
|
|||||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
||||||
)
|
)
|
||||||
|
|
||||||
#define GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(prefix, o, i, y, x, sub_group_size) \
|
#define GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(prefix, o, i, y, x, sub_group_size) \
|
||||||
CAT(prefix, _OFFSET) + \
|
CAT(prefix, _OFFSET) + \
|
||||||
((o) % (sub_group_size)) + \
|
((o) % (sub_group_size)) + \
|
||||||
(sub_group_size)*( \
|
(sub_group_size)*( \
|
||||||
@ -1495,16 +1522,6 @@ inline uint get_os_i_yxs_osv_yxsv4_index(uint o, uint i, uint y, uint x, uint i_
|
|||||||
CAT(prefix, _SIZE_Y), \
|
CAT(prefix, _SIZE_Y), \
|
||||||
4)
|
4)
|
||||||
|
|
||||||
#define GET_FILTER_OS_IYX_OSV32__AI32_INDEX(prefix, o, i, y, x, sub_group_size) \
|
|
||||||
CAT(prefix, _OFFSET) + \
|
|
||||||
((o) % (sub_group_size)) + \
|
|
||||||
(sub_group_size)*( \
|
|
||||||
(x)*CAT(prefix, _X_PITCH) + \
|
|
||||||
(y)*CAT(prefix, _Y_PITCH) + \
|
|
||||||
(i)*CAT(prefix, _IFM_PITCH) + \
|
|
||||||
((o) / (sub_group_size))*CAT(prefix, _OFM_PITCH) \
|
|
||||||
)
|
|
||||||
|
|
||||||
#define GET_FILTER_G_OS_IYX_OSV16(prefix, g, o, i, y, x, sub_group_size) \
|
#define GET_FILTER_G_OS_IYX_OSV16(prefix, g, o, i, y, x, sub_group_size) \
|
||||||
CAT(prefix, _OFFSET) + \
|
CAT(prefix, _OFFSET) + \
|
||||||
(g * CAT(prefix, _GROUPS_PITCH)) + \
|
(g * CAT(prefix, _GROUPS_PITCH)) + \
|
||||||
|
@ -25,19 +25,20 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
|
|||||||
return GET_FILTER_INDEX_5D(INPUT0, 0, o, i, z, y, x);
|
return GET_FILTER_INDEX_5D(INPUT0, 0, o, i, z, y, x);
|
||||||
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16 || \
|
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16 || \
|
||||||
defined INPUT0_LAYOUT_OS_I_OSV16 || \
|
defined INPUT0_LAYOUT_OS_I_OSV16 || \
|
||||||
defined INPUT0_LAYOUT_OS_I_OSV8__AI8 || \
|
|
||||||
defined INPUT0_LAYOUT_OS_I_OSV16__AI8
|
defined INPUT0_LAYOUT_OS_I_OSV16__AI8
|
||||||
return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 16);
|
||||||
|
#elif defined INPUT0_LAYOUT_OS_I_OSV8__AI8
|
||||||
|
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 8);
|
||||||
#elif defined INPUT0_LAYOUT_IYX_OSV32
|
#elif defined INPUT0_LAYOUT_IYX_OSV32
|
||||||
return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, 32);
|
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32);
|
||||||
#elif defined INPUT0_LAYOUT_OS_IYX_OSV32__AI32
|
#elif defined INPUT0_LAYOUT_OS_IYX_OSV32__AI32
|
||||||
return GET_FILTER_OS_IYX_OSV32__AI32_INDEX(OUTPUT, o, i, y, x, 32);
|
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 32);
|
||||||
#elif defined INPUT0_LAYOUT_O_IS_YX_ISV16
|
#elif defined INPUT0_LAYOUT_O_IS_YX_ISV16
|
||||||
return GET_FILTER_O_IS_YX_ISV16_INDEX(INPUT0, o, i, y, x, 16);
|
return GET_FILTER_O_IS_YX_ISV16_INDEX(INPUT0, o, i, y, x, 16);
|
||||||
#elif defined INPUT0_LAYOUT_IYX_OSV64
|
#elif defined INPUT0_LAYOUT_IYX_OSV64
|
||||||
return GET_FILTER_OS_IYX_OSV8_INDEX(INPUT0, o, i, y, x, 64);
|
return GET_FILTER_OS_IYX_OSV_INDEX(INPUT0, o, i, y, x, 64);
|
||||||
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16_ROTATE_180
|
#elif defined INPUT0_LAYOUT_OS_IYX_OSV16_ROTATE_180
|
||||||
return GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
||||||
#elif defined INPUT0_LAYOUT_I_YXS_OS_YXSV2_OSV16
|
#elif defined INPUT0_LAYOUT_I_YXS_OS_YXSV2_OSV16
|
||||||
return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
||||||
#elif defined INPUT0_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32
|
#elif defined INPUT0_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32
|
||||||
@ -61,11 +62,11 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
|
|||||||
#elif defined INPUT0_LAYOUT_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4
|
#elif defined INPUT0_LAYOUT_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4
|
||||||
return GET_FILTER_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4(INPUT0, o, i, y, x);
|
return GET_FILTER_OS_IS_Y_X8_OSV8_ISV4_SWIZZLED_BY_4(INPUT0, o, i, y, x);
|
||||||
#elif defined INPUT0_LAYOUT_OS_IS_YX_ISV16_OSV16
|
#elif defined INPUT0_LAYOUT_OS_IS_YX_ISV16_OSV16
|
||||||
return GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(INPUT0, o, i, y, x, 16, 16);
|
||||||
#elif defined INPUT0_LAYOUT_OIYX_O16
|
#elif defined INPUT0_LAYOUT_OIYX_O16
|
||||||
return GET_FILTER_OIYX_O16(INPUT0, o, i, y, x);
|
return GET_FILTER_OIYX_O16(INPUT0, o, i, y, x);
|
||||||
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV16_OSV16
|
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV16_OSV16
|
||||||
return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(INPUT0, o, i, z, y, x, 16, 16);
|
||||||
#elif defined INPUT0_LAYOUT_IS_OS_ZYX_ISV16_OSV16
|
#elif defined INPUT0_LAYOUT_IS_OS_ZYX_ISV16_OSV16
|
||||||
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(INPUT0, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||||
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV16
|
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV16
|
||||||
@ -219,19 +220,20 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
|
|||||||
return GET_FILTER_INDEX_5D(OUTPUT, 0, o, i, z, y, x);
|
return GET_FILTER_INDEX_5D(OUTPUT, 0, o, i, z, y, x);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16 || \
|
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16 || \
|
||||||
defined OUTPUT_LAYOUT_OS_I_OSV16 || \
|
defined OUTPUT_LAYOUT_OS_I_OSV16 || \
|
||||||
defined OUTPUT_LAYOUT_OS_I_OSV8__AI8 || \
|
|
||||||
defined OUTPUT_LAYOUT_OS_I_OSV16__AI8
|
defined OUTPUT_LAYOUT_OS_I_OSV16__AI8
|
||||||
return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 16);
|
||||||
|
#elif defined OUTPUT_LAYOUT_OS_I_OSV8__AI8
|
||||||
|
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 8);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV32
|
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV32
|
||||||
return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, 32);
|
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV32__AI32
|
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV32__AI32
|
||||||
return GET_FILTER_OS_IYX_OSV32__AI32_INDEX(OUTPUT, o, i, y, x, 32);
|
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 32);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV64
|
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV64
|
||||||
return GET_FILTER_OS_IYX_OSV8_INDEX(OUTPUT, o, i, y, x, 64);
|
return GET_FILTER_OS_IYX_OSV_INDEX(OUTPUT, o, i, y, x, 64);
|
||||||
#elif defined OUTPUT_LAYOUT_O_IS_YX_ISV16
|
#elif defined OUTPUT_LAYOUT_O_IS_YX_ISV16
|
||||||
return GET_FILTER_O_IS_YX_ISV16_INDEX(OUTPUT, o, i, y, x, 16);
|
return GET_FILTER_O_IS_YX_ISV16_INDEX(OUTPUT, o, i, y, x, 16);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16_ROTATE_180
|
#elif defined OUTPUT_LAYOUT_OS_IYX_OSV16_ROTATE_180
|
||||||
return GET_FILTER_OS_IYX_OSV8_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IYX_OSV_ROTATE_180_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
||||||
#elif defined OUTPUT_LAYOUT_I_YXS_OS_YXSV2_OSV16
|
#elif defined OUTPUT_LAYOUT_I_YXS_OS_YXSV2_OSV16
|
||||||
return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_I_YXS_OS_YXSV2_OSV_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
||||||
#elif defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32
|
#elif defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV16__AO32 || defined OUTPUT_LAYOUT_IY_XS_OS_XSV2_OSV8__AO32
|
||||||
@ -313,11 +315,11 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
|
|||||||
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4
|
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4
|
||||||
return GET_FILTER_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, z, y, x);
|
return GET_FILTER_OS_IS_ZYX_OSA4_ISA8_OSV8_ISV4_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, z, y, x);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV16_OSV16
|
#elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV16_OSV16
|
||||||
return GET_FILTER_OS_IS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IS_YX_ISV_OSV_INDEX(OUTPUT, o, i, y, x, 16, 16);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_YXI_OSV16
|
#elif defined OUTPUT_LAYOUT_OS_YXI_OSV16
|
||||||
return GET_FILTER_OS_YXI_OSV16(OUTPUT, o, i, y, x);
|
return GET_FILTER_OS_YXI_OSV16(OUTPUT, o, i, y, x);
|
||||||
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_ISV16_OSV16
|
#elif defined OUTPUT_LAYOUT_OS_IS_ZYX_ISV16_OSV16
|
||||||
return GET_FILTER_OS_IS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(OUTPUT, o, i, z, y, x, 16, 16);
|
||||||
#elif defined OUTPUT_LAYOUT_IS_OS_ZYX_ISV16_OSV16
|
#elif defined OUTPUT_LAYOUT_IS_OS_ZYX_ISV16_OSV16
|
||||||
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
|
return GET_FILTER_IS_OS_ZYX_ISV16_OSV16_INDEX(OUTPUT, o, i, z, y, x, SUB_GROUP_SIZE);
|
||||||
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV16
|
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV16
|
||||||
|
Loading…
Reference in New Issue
Block a user