[GPU] Added some formats for pvc (#19388)

This commit is contained in:
Vladimir Paramuzov 2023-08-25 15:09:42 +04:00 committed by GitHub
parent 1bdf4f0ab9
commit a45e5e03c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 118 additions and 0 deletions

View File

@ -146,6 +146,8 @@ struct format {
is_os_zyx_isv16_osv16, ///< format used for weights for blocked 3D deconvolution is_os_zyx_isv16_osv16, ///< format used for weights for blocked 3D deconvolution
is_os_yx_isv16_osv16, ///< format used for weights for blocked deconvolution is_os_yx_isv16_osv16, ///< format used for weights for blocked deconvolution
is_os_yx_isv16_osv8, ///< format used for weights for blocked deconvolution is_os_yx_isv16_osv8, ///< format used for weights for blocked deconvolution
is_os_yx_isv16_osv4, ///< format used for weights for blocked deconvolution
is_os_yx_isv16_osv2, ///< format used for weights for blocked deconvolution
os_is_yx_isv8_osv16_isv2, ///< format used for weights for blocked 2D convolution os_is_yx_isv8_osv16_isv2, ///< format used for weights for blocked 2D convolution
os_is_zyx_isv8_osv16_isv2, ///< format used for weights for blocked 3D convolution os_is_zyx_isv8_osv16_isv2, ///< format used for weights for blocked 3D convolution
///< os - output feature maps slice, i - input feature maps, ///< os - output feature maps slice, i - input feature maps,
@ -188,6 +190,7 @@ struct format {
os_is_yx_isa8_osv8_isv2, os_is_yx_isa8_osv8_isv2,
is_os_yx_isa8_osv8_isv2, is_os_yx_isa8_osv8_isv2,
is_os_yx_isa8_osv8_isv4, is_os_yx_isa8_osv8_isv4,
is_os_yx_osa8_isv16_osv4,
os_is_zyx_isa8_osv8_isv2, os_is_zyx_isa8_osv8_isv2,
is_os_zyx_isa8_osv8_isv2, is_os_zyx_isa8_osv8_isv2,
is_os_zyx_isa8_osv8_isv4, is_os_zyx_isa8_osv8_isv4,

View File

@ -524,6 +524,10 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::is_os_yx_isv16_osv16; return kernel_selector::weights_layout::is_os_yx_isv16_osv16;
case format::is_os_yx_isv16_osv8: case format::is_os_yx_isv16_osv8:
return kernel_selector::weights_layout::is_os_yx_isv16_osv8; return kernel_selector::weights_layout::is_os_yx_isv16_osv8;
case format::is_os_yx_isv16_osv4:
return kernel_selector::weights_layout::is_os_yx_isv16_osv4;
case format::is_os_yx_isv16_osv2:
return kernel_selector::weights_layout::is_os_yx_isv16_osv2;
case format::i_yxs_os_yxsv2_osv16: case format::i_yxs_os_yxsv2_osv16:
return kernel_selector::weights_layout::i_yxs_os_yxsv2_osv16; return kernel_selector::weights_layout::i_yxs_os_yxsv2_osv16;
case format::is_os_yx_osa4_isa8_osv8_isv4: case format::is_os_yx_osa4_isa8_osv8_isv4:
@ -640,6 +644,8 @@ kernel_selector::weights_layout to_weights_layout(format f, bool is_grouped) {
return kernel_selector::weights_layout::is_os_yx_isa8_osv8_isv2; return kernel_selector::weights_layout::is_os_yx_isa8_osv8_isv2;
case format::is_os_yx_isa8_osv8_isv4: case format::is_os_yx_isa8_osv8_isv4:
return kernel_selector::weights_layout::is_os_yx_isa8_osv8_isv4; return kernel_selector::weights_layout::is_os_yx_isa8_osv8_isv4;
case format::is_os_yx_osa8_isv16_osv4:
return kernel_selector::weights_layout::is_os_yx_osa8_isv16_osv4;
case format::is_os_yx_isa2_osa8_isv8_osv2: case format::is_os_yx_isa2_osa8_isv8_osv2:
return kernel_selector::weights_layout::is_os_yx_isa2_osa8_isv8_osv2; return kernel_selector::weights_layout::is_os_yx_isa2_osa8_isv8_osv2;
case format::is_os_yx_isa4_osa8_isv8_osv4: case format::is_os_yx_isa4_osa8_isv8_osv4:
@ -812,6 +818,10 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::is_os_yx_isv16_osv16; return cldnn::format::is_os_yx_isv16_osv16;
case kernel_selector::weights_layout::is_os_yx_isv16_osv8: case kernel_selector::weights_layout::is_os_yx_isv16_osv8:
return cldnn::format::is_os_yx_isv16_osv8; return cldnn::format::is_os_yx_isv16_osv8;
case kernel_selector::weights_layout::is_os_yx_isv16_osv4:
return cldnn::format::is_os_yx_isv16_osv4;
case kernel_selector::weights_layout::is_os_yx_isv16_osv2:
return cldnn::format::is_os_yx_isv16_osv2;
case kernel_selector::weights_layout::is_os_zyx_isa8_osv8_isv2: case kernel_selector::weights_layout::is_os_zyx_isa8_osv8_isv2:
return cldnn::format::is_os_zyx_isa8_osv8_isv2; return cldnn::format::is_os_zyx_isa8_osv8_isv2;
case kernel_selector::weights_layout::is_os_zyx_isa8_osv8_isv4: case kernel_selector::weights_layout::is_os_zyx_isa8_osv8_isv4:
@ -822,6 +832,8 @@ cldnn::format::type from_weights_layout(kernel_selector::weights_layout l) {
return cldnn::format::is_os_yx_isa8_osv8_isv2; return cldnn::format::is_os_yx_isa8_osv8_isv2;
case kernel_selector::weights_layout::is_os_yx_isa8_osv8_isv4: case kernel_selector::weights_layout::is_os_yx_isa8_osv8_isv4:
return cldnn::format::is_os_yx_isa8_osv8_isv4; return cldnn::format::is_os_yx_isa8_osv8_isv4;
case kernel_selector::weights_layout::is_os_yx_osa8_isv16_osv4:
return cldnn::format::is_os_yx_osa8_isv16_osv4;
case kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv2: case kernel_selector::weights_layout::os_is_yx_isa8_osv8_isv2:
return cldnn::format::os_is_yx_isa8_osv8_isv2; return cldnn::format::os_is_yx_isa8_osv8_isv2;
case kernel_selector::weights_layout::is_os_yx_osa4_isa8_osv8_isv4: case kernel_selector::weights_layout::is_os_yx_osa4_isa8_osv8_isv4:

View File

@ -16,6 +16,18 @@
isv \ isv \
) )
// 2D (yx) convenience wrapper around get_is_os_zyx_isv_osv_index().
// The z coordinate is fixed to 0 and the z extent to 1; the x/y extents and the
// input/output feature counts are taken from the <prefix>_SIZE_X, <prefix>_SIZE_Y,
// <prefix>_IFM_NUM and <prefix>_OFM_NUM definitions. osv and isv are the output and
// input feature block sizes and are forwarded unchanged.
// Note: this macro does not add any <prefix>_OFFSET term to the result.
#define GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(prefix, o, i, y, x, osv, isv) \
get_is_os_zyx_isv_osv_index( \
o, i, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
1, \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
osv, \
isv \
)
#define GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(prefix, o, i, z, y, x, osv, isv) \ #define GET_FILTER_OS_IS_ZYX_ISV_OSV_INDEX(prefix, o, i, z, y, x, osv, isv) \
get_os_is_zyx_isv_osv_index( \ get_os_is_zyx_isv_osv_index( \
o, i, z, y, x, \ o, i, z, y, x, \
@ -112,6 +124,32 @@ inline uint get_os_is_zyx_isv_osv_index(uint o, uint i, uint z, uint y, uint x,
return output_offset; return output_offset;
} }
// Returns the linear element offset of weight (o, i, z, y, x) in a blocked layout whose
// dimensions, from outermost to innermost, are:
//   input-feature slice, output-feature slice, z, y, x, input-feature-in-block, output-feature-in-block.
//
//   o, i            - output / input feature indices
//   z, y, x         - spatial coordinates
//   x_size..z_size  - spatial extents
//   i_size          - input feature count (accepted for signature symmetry, not used here)
//   o_size          - output feature count; rounded up to a whole number of osv_size blocks
//   osv_size        - output feature block size
//   isv_size        - input feature block size
inline uint get_is_os_zyx_isv_osv_index(uint o, uint i, uint z, uint y, uint x,
    uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
{
    // Number of output-feature slices, counting a trailing partial block as a full one.
    const uint os_count = (o_size + osv_size - 1) / osv_size;

    // Accumulate from the outermost dimension inwards: at each step scale by the
    // extent of the next dimension and add the coordinate along it.
    uint linear = i / isv_size;                    // input-feature slice
    linear = linear * os_count + o / osv_size;     // output-feature slice
    linear = linear * z_size + z;
    linear = linear * y_size + y;
    linear = linear * x_size + x;
    linear = linear * isv_size + i % isv_size;     // position inside the input block
    linear = linear * osv_size + o % osv_size;     // position inside the output block
    return linear;
}
inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x, inline uint get_os_is_zyx_osv_isv_index(uint o, uint i, uint z, uint y, uint x,
uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size) uint x_size, uint y_size, uint z_size, uint i_size, uint o_size, uint osv_size, uint isv_size)
{ {
@ -722,6 +760,37 @@ inline uint get_is_os_yx_isa8_osv8_isv4_index(uint o, uint i, uint y, uint x, ui
CAT(prefix, _OFM_NUM), \ CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET)) CAT(prefix, _OFFSET))
// Returns the linear element offset of weight (o, i, y, x) in the
// is_os_yx_osa8_isv16_osv4 layout. From outermost to innermost the dimensions are:
//   input slice (i / 16), output slice (o / 32), y, x,
//   outer output index ((o / 4) % 8), input index in block (i % 16), inner output index (o % 4).
//
//   size_x, size_y - spatial extents
//   size_ifm       - input feature count (not used by the computation)
//   size_ofm       - output feature count; rounded up to a whole number of 32-wide slices
//   offset         - base offset added to the computed index
inline uint get_is_os_yx_osa8_isv16_osv4_index(uint o, uint i, uint y, uint x, uint size_x,
    uint size_y, uint size_ifm, uint size_ofm, uint offset)
{
    // A single (is, os, y, x) cell holds 8 * 16 * 4 elements.
    const uint cell_elems = 8 * 16 * 4;
    // Output feature slices of 32, counting a trailing partial slice as a full one.
    const uint os_slices = (size_ofm + 31) / 32;

    // Position inside the cell: osv4 is innermost, then isv16, then osa8.
    const uint in_cell = (o % 4) + (i % 16) * 4 + ((o / 4) % 8) * (16 * 4);

    // Index of the cell itself, ordered as [is][os][y][x].
    uint cell = (i / 16) * os_slices + o / 32;
    cell = cell * size_y + y;
    cell = cell * size_x + x;

    return offset + in_cell + cell * cell_elems;
}
// Wrapper around get_is_os_yx_osa8_isv16_osv4_index() that fills in the tensor
// description from the <prefix>_SIZE_X, <prefix>_SIZE_Y, <prefix>_IFM_NUM,
// <prefix>_OFM_NUM and <prefix>_OFFSET definitions, so callers only pass the
// logical coordinates (o, i, y, x).
#define GET_FILTER_IS_OS_YX_OSA8_ISV16_OSV4_INDEX(prefix, o, i, y, x) \
get_is_os_yx_osa8_isv16_osv4_index( \
o, i, y, x, CAT(prefix, _SIZE_X ), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _IFM_NUM), \
CAT(prefix, _OFM_NUM), \
CAT(prefix, _OFFSET))
inline uint get_os_is_zyx_isa8_osv8_isv4_index(uint o, uint i, uint z, uint y, uint x, inline uint get_os_is_zyx_isa8_osv8_isv4_index(uint o, uint i, uint z, uint y, uint x,
uint size_x, uint size_y, uint size_z, uint size_x, uint size_y, uint size_z,
uint size_ifm, uint size_ofm, uint offset) uint size_ifm, uint size_ofm, uint offset)

View File

@ -320,6 +320,10 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV8 #elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV8
return GET_FILTER_IS_OS_YX_ISV16_OSV8_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_IS_OS_YX_ISV16_OSV8_INDEX(INPUT0, o, i, y, x, SUB_GROUP_SIZE);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV4
return GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(INPUT0, o, i, y, x, 16, 4);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISV16_OSV2
return GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(INPUT0, o, i, y, x, 16, 2);
#elif defined INPUT0_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV2 #elif defined INPUT0_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV2
return GET_FILTER_G_OS_IS_ZYX_ISA8_OSV8_ISV2_INDEX(INPUT0, g, o, i, z, y, x); return GET_FILTER_G_OS_IS_ZYX_ISA8_OSV8_ISV2_INDEX(INPUT0, g, o, i, z, y, x);
#elif defined INPUT0_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV4 #elif defined INPUT0_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV4
@ -340,6 +344,8 @@ inline uint FUNC(get_input_index)(uint g, uint o, uint i, uint z, uint y, uint x
return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV2_INDEX(INPUT0, o, i, y, x); return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV2_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_IS_OS_YX_ISA8_OSV8_ISV4 #elif defined INPUT0_LAYOUT_IS_OS_YX_ISA8_OSV8_ISV4
return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV4_INDEX(INPUT0, o, i, y, x); return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV4_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_IS_OS_YX_OSA8_ISV16_OSV4
return GET_FILTER_IS_OS_YX_OSA8_ISV16_OSV4_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_OSV32_ISV32_SWIZZLED_BY_4 #elif defined INPUT0_LAYOUT_OS_IS_OSV32_ISV32_SWIZZLED_BY_4
return GET_FILTER_OS_IS_OSV32_ISV32_SWIZZLED_BY_4_INDEX(INPUT0, o, i, y, x); return GET_FILTER_OS_IS_OSV32_ISV32_SWIZZLED_BY_4_INDEX(INPUT0, o, i, y, x);
#elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV8_OSV16_ISV2 #elif defined INPUT0_LAYOUT_OS_IS_ZYX_ISV8_OSV16_ISV2
@ -577,6 +583,10 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_IS_OS_YX_ISV16_OSV16_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV8 #elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV8
return GET_FILTER_IS_OS_YX_ISV16_OSV8_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE); return GET_FILTER_IS_OS_YX_ISV16_OSV8_INDEX(OUTPUT, o, i, y, x, SUB_GROUP_SIZE);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV4
return GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(OUTPUT, o, i, y, x, 16, 4);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISV16_OSV2
return GET_FILTER_IS_OS_YX_ISV_OSV_INDEX(OUTPUT, o, i, y, x, 16, 2);
#elif defined OUTPUT_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV2 #elif defined OUTPUT_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV2
return GET_FILTER_G_OS_IS_ZYX_ISA8_OSV8_ISV2_INDEX(OUTPUT, g, o, i, z, y, x); return GET_FILTER_G_OS_IS_ZYX_ISA8_OSV8_ISV2_INDEX(OUTPUT, g, o, i, z, y, x);
#elif defined OUTPUT_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV4 #elif defined OUTPUT_LAYOUT_G_OS_IS_ZYX_ISA8_OSV8_ISV4
@ -597,6 +607,8 @@ inline uint FUNC(get_output_index)(uint g, uint o, uint i, uint z, uint y, uint
return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x); return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV2_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_ISA8_OSV8_ISV4 #elif defined OUTPUT_LAYOUT_IS_OS_YX_ISA8_OSV8_ISV4
return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV4_INDEX(OUTPUT, o, i, y, x); return GET_FILTER_IS_OS_YX_ISA8_OSV8_ISV4_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_IS_OS_YX_OSA8_ISV16_OSV4
return GET_FILTER_IS_OS_YX_OSA8_ISV16_OSV4_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_OSV32_ISV32_SWIZZLED_BY_4 #elif defined OUTPUT_LAYOUT_OS_IS_OSV32_ISV32_SWIZZLED_BY_4
return GET_FILTER_OS_IS_OSV32_ISV32_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, y, x); return GET_FILTER_OS_IS_OSV32_ISV32_SWIZZLED_BY_4_INDEX(OUTPUT, o, i, y, x);
#elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV8_OSV16_ISV2 #elif defined OUTPUT_LAYOUT_OS_IS_YX_ISV8_OSV16_ISV2

View File

@ -354,11 +354,14 @@ std::string toString(WeightsLayout layout) {
case WeightsLayout::is_os_zyx_isv16_osv16: return "IS_OS_ZYX_ISV16_OSV16"; case WeightsLayout::is_os_zyx_isv16_osv16: return "IS_OS_ZYX_ISV16_OSV16";
case WeightsLayout::is_os_yx_isv16_osv16: return "IS_OS_YX_ISV16_OSV16"; case WeightsLayout::is_os_yx_isv16_osv16: return "IS_OS_YX_ISV16_OSV16";
case WeightsLayout::is_os_yx_isv16_osv8: return "IS_OS_YX_ISV16_OSV8"; case WeightsLayout::is_os_yx_isv16_osv8: return "IS_OS_YX_ISV16_OSV8";
case WeightsLayout::is_os_yx_isv16_osv4: return "IS_OS_YX_ISV16_OSV4";
case WeightsLayout::is_os_yx_isv16_osv2: return "IS_OS_YX_ISV16_OSV2";
case WeightsLayout::is_os_zyx_isa8_osv8_isv2: return "IS_OS_ZYX_ISA8_OSV8_ISV2"; case WeightsLayout::is_os_zyx_isa8_osv8_isv2: return "IS_OS_ZYX_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_zyx_isa8_osv8_isv4: return "IS_OS_ZYX_ISA8_OSV8_ISV4"; case WeightsLayout::is_os_zyx_isa8_osv8_isv4: return "IS_OS_ZYX_ISA8_OSV8_ISV4";
case WeightsLayout::os_is_zyx_isa8_osv8_isv2: return "OS_IS_ZYX_ISA8_OSV8_ISV2"; case WeightsLayout::os_is_zyx_isa8_osv8_isv2: return "OS_IS_ZYX_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_yx_isa8_osv8_isv2: return "IS_OS_YX_ISA8_OSV8_ISV2"; case WeightsLayout::is_os_yx_isa8_osv8_isv2: return "IS_OS_YX_ISA8_OSV8_ISV2";
case WeightsLayout::is_os_yx_isa8_osv8_isv4: return "IS_OS_YX_ISA8_OSV8_ISV4"; case WeightsLayout::is_os_yx_isa8_osv8_isv4: return "IS_OS_YX_ISA8_OSV8_ISV4";
case WeightsLayout::is_os_yx_osa8_isv16_osv4: return "IS_OS_YX_OSA8_ISV16_OSV4";
case WeightsLayout::os_is_yx_isa8_osv8_isv2: return "OS_IS_YX_ISA8_OSV8_ISV2"; case WeightsLayout::os_is_yx_isa8_osv8_isv2: return "OS_IS_YX_ISA8_OSV8_ISV2";
case WeightsLayout::os_is_zyx_isv8_osv16_isv2: return "OS_IS_ZYX_ISV8_OSV16_ISV2"; case WeightsLayout::os_is_zyx_isv8_osv16_isv2: return "OS_IS_ZYX_ISV8_OSV16_ISV2";
case WeightsLayout::os_zyxi_osv16: return "OS_ZYXI_OSV16"; case WeightsLayout::os_zyxi_osv16: return "OS_ZYXI_OSV16";

View File

@ -147,6 +147,8 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1 } }, { WeightsLayout::is_os_zyx_isv16_osv16, { 0, 1, 2, 4, 3, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_isv16_osv16, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv8, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_isv16_osv8, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv4, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isv16_osv2, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isa2_osa8_isv8_osv2, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_isa2_osa8_isv8_osv2, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isa4_osa8_isv8_osv4, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_isa4_osa8_isv8_osv4, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_osa4_isa8_osv8_isv4, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_osa4_isa8_osv8_isv4, { 0, 1, -1, 3, 2, -1 } },
@ -159,6 +161,7 @@ WeightsTensor::WeightsChannelArray WeightsTensor::weightsChannelArray {{
{ WeightsLayout::os_is_yx_isa8_osv8_isv2, { 0, 1, -1, 2, 3, -1 } }, { WeightsLayout::os_is_yx_isa8_osv8_isv2, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::is_os_yx_isa8_osv8_isv2, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_isa8_osv8_isv2, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_isa8_osv8_isv4, { 0, 1, -1, 3, 2, -1 } }, { WeightsLayout::is_os_yx_isa8_osv8_isv4, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::is_os_yx_osa8_isv16_osv4, { 0, 1, -1, 3, 2, -1 } },
{ WeightsLayout::os_zyxi_osv16, { 1, 2, 3, 0, 4, -1 } }, { WeightsLayout::os_zyxi_osv16, { 1, 2, 3, 0, 4, -1 } },
{ WeightsLayout::os_i_yxs_osv4_yxsv4, { 0, 1, -1, 2, 3, -1 } }, { WeightsLayout::os_i_yxs_osv4_yxsv4, { 0, 1, -1, 2, 3, -1 } },
{ WeightsLayout::os_y_is_x_osv8_isv2, { 0, 2, -1, 1, 3, -1 } }, { WeightsLayout::os_y_is_x_osv8_isv2, { 0, 2, -1, 1, 3, -1 } },
@ -778,6 +781,16 @@ NDims WeightsTensor::GetSimpleDims(const std::vector<size_t>& d, WeightsLayout l
newDims[2] = RoundUp(newDims[2], 16); newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 8); newDims[3] = RoundUp(newDims[3], 8);
break; break;
case is_os_yx_isv16_osv4:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 4);
break;
case is_os_yx_isv16_osv2:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16);
newDims[3] = RoundUp(newDims[3], 2);
break;
case os_is_yx_isv8_osv16_isv2: case os_is_yx_isv8_osv16_isv2:
assert(newDims.size() == 4); assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 16); newDims[2] = RoundUp(newDims[2], 16);

View File

@ -100,6 +100,8 @@ enum WeightsLayout {
is_os_zyx_isv16_osv16, is_os_zyx_isv16_osv16,
is_os_yx_isv16_osv16, is_os_yx_isv16_osv16,
is_os_yx_isv16_osv8, is_os_yx_isv16_osv8,
is_os_yx_isv16_osv4,
is_os_yx_isv16_osv2,
os_is_zyx_isv8_osv16_isv2, os_is_zyx_isv8_osv16_isv2,
os_is_yx_isv8_osv16_isv2, os_is_yx_isv8_osv16_isv2,
os_is_yx_isv16_osv16, os_is_yx_isv16_osv16,
@ -153,6 +155,7 @@ enum WeightsLayout {
os_is_yx_isa8_osv8_isv2, os_is_yx_isa8_osv8_isv2,
is_os_yx_isa8_osv8_isv2, is_os_yx_isa8_osv8_isv2,
is_os_yx_isa8_osv8_isv4, is_os_yx_isa8_osv8_isv4,
is_os_yx_osa8_isv16_osv4,
is_os_yx_isa2_osa8_isv8_osv2, is_os_yx_isa2_osa8_isv8_osv2,
g_os_is_yx_osa2_isa8_osv16_isv4, g_os_is_yx_osa2_isa8_osv16_isv4,
g_os_is_yx_osa2_isa8_osv16_isv2, g_os_is_yx_osa2_isa8_osv16_isv2,

View File

@ -129,6 +129,8 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(is_os_zyx_isv16_osv16, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 16}, {0, 16}}), FMT_TRAITS(is_os_zyx_isv16_osv16, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 16}}), FMT_TRAITS(is_os_yx_isv16_osv16, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 16}}),
FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 8}}), FMT_TRAITS(is_os_yx_isv16_osv8, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 8}}),
FMT_TRAITS(is_os_yx_isv16_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 4}}),
FMT_TRAITS(is_os_yx_isv16_osv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 16}, {0, 2}}),
FMT_TRAITS(is_os_zyx_isa8_osv8_isv2, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(is_os_zyx_isa8_osv8_isv2, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(is_os_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(is_os_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {1, 0, 2, 3, 4}, "iozyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(os_is_zyx_isa8_osv8_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 8}, {1, 4}}),
@ -136,6 +138,7 @@ static const std::map<format::type, format_traits> format_traits_map {
FMT_TRAITS(os_is_zyx_isa8_osv16_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}), FMT_TRAITS(os_is_zyx_isa8_osv16_isv4, 1, 1, 3, 0, {0, 1, 2, 3, 4}, "oizyx", "oixyz", {{1, 8}, {0, 16}, {1, 4}}),
FMT_TRAITS(is_os_yx_isa8_osv8_isv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(is_os_yx_isa8_osv8_isv2, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(is_os_yx_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(is_os_yx_isa8_osv8_isv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(is_os_yx_osa8_isv16_osv4, 1, 1, 2, 0, {1, 0, 2, 3}, "ioyx", "oixy", {{0, 8}, {1, 16}, {0, 4}}),
FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 4}}), FMT_TRAITS(os_is_yx_isa8_osv8_isv4, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 4}}),
FMT_TRAITS(os_is_yx_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 2}}), FMT_TRAITS(os_is_yx_isa8_osv8_isv2, 1, 1, 2, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{1, 8}, {0, 8}, {1, 2}}),
FMT_TRAITS(os_is_osv32_isv32_swizzled_by_4, 1, 1, 0, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}), FMT_TRAITS(os_is_osv32_isv32_swizzled_by_4, 1, 1, 0, 0, {0, 1, 2, 3}, "oiyx", "oixy", {{0, 32}, {1, 32}}),