[GPU] Change FQ output for first Convolution (#9200)

* update onednn_gpu

* [GPU] Add bs_fs_yx_bsv8_fsv4 format

Co-authored-by: Kim,SungEun <sungeun.kim@intel.com>
This commit is contained in:
Sergey Shlyapnikov 2021-12-15 13:15:13 +03:00 committed by GitHub
parent b492b59136
commit 1177d2b282
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 102 additions and 26 deletions

View File

@ -85,6 +85,7 @@ struct format {
bs_fs_zyx_bsv16_fsv16, ///< format used for 3D blocked convolution (batch and features blocked by 16)
bs_fs_yx_bsv16_fsv16, ///< format used for 2D blocked convolution (batch and features blocked by 16)
bs_fs_yx_bsv4_fsv4, ///< format used for 2D blocked convolution (batch and features blocked by 4)
bs_fs_yx_bsv8_fsv4, ///< format used for 2D blocked convolution (batch blocked by 8, features blocked by 4)
bs_fs_yx_bsv4_fsv2, ///< format used for 2D blocked convolution (batch blocked by 4, features blocked by 2)
bs_fs_zyx_bsv4_fsv4, ///< format used for 3D blocked convolution (batch and features blocked by 4)
bs_fs_zyx_bsv4_fsv2, ///< format used for 3D blocked convolution (batch blocked by 4, features blocked by 2)
@ -255,6 +256,7 @@ struct format {
{ bs_fs_zyx_bsv16_fsv16, { 1, 1, 3, 0, "bfzyx", "bfxyz", {{0, 16 }, {1, 16}}}},
{ bs_fs_yx_bsv16_fsv16, { 1, 1, 2, 0, "bfyx", "bfxy?", {{0, 16 }, {1, 16}}}},
{ bs_fs_yx_bsv4_fsv4, { 1, 1, 2, 0, "bfyx", "bfxy?", {{0, 4 }, {1, 4}}}},
{ bs_fs_yx_bsv8_fsv4, { 1, 1, 2, 0, "bfyx", "bfxy?", {{0, 8 }, {1, 4}}}},
{ bs_fs_yx_bsv4_fsv2, { 1, 1, 2, 0, "bfyx", "bfxy?", {{0, 4 }, {1, 2}}}},
{ bs_fs_zyx_bsv4_fsv4, { 1, 1, 3, 0, "bfzyx", "bfxyz", {{0, 4 }, {1, 4}}}},
{ bs_fs_zyx_bsv4_fsv2, { 1, 1, 3, 0, "bfzyx", "bfxyz", {{0, 4 }, {1, 2}}}},

View File

@ -29,6 +29,7 @@ DataTensor::DataChannelArray DataTensor::dataChannelArray {{
{ DataLayout::bs_fs_zyx_bsv16_fsv16, { 0, 1, 2, -1, 3, 4 } },
{ DataLayout::bs_fs_yx_bsv16_fsv16, { 0, 1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv4_fsv4, { 0, 1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv8_fsv4, { 0, 1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv4_fsv2, { 0, 1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv32_fsv32, { 0, 1, -1, -1, 2, 3 } },
{ DataLayout::bs_fs_yx_bsv32_fsv16, { 0, 1, -1, -1, 2, 3 } },
@ -206,6 +207,11 @@ NDims DataTensor::GetSimpleDims(const std::vector<size_t>& d, DataLayout l) {
newDims[2] = RoundUp(newDims[2], 4);
newDims[3] = RoundUp(newDims[3], 4);
break;
case bs_fs_yx_bsv8_fsv4:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 4);
newDims[3] = RoundUp(newDims[3], 8);
break;
case bs_fs_yx_bsv4_fsv2:
assert(newDims.size() == 4);
newDims[2] = RoundUp(newDims[2], 2);

View File

@ -39,6 +39,7 @@ enum DataLayout {
bs_fs_yx_bsv16_fsv16, // batch, feature, 2D spatial. Blocks of 16 batch and channels
bs_fs_zyx_bsv16_fsv16, // batch, feature, 3D spatial. Blocks of 16 batch and channels
bs_fs_yx_bsv4_fsv4, // batch, feature, 2D spatial. Blocks of 4 batch and 4 channels
bs_fs_yx_bsv8_fsv4, // batch, feature, 2D spatial. Blocks of 8 batch and 4 channels
bs_fs_yx_bsv4_fsv2, // batch, feature, 2D spatial. Blocks of 4 batch and 2 channels
bs_fs_yx_bsv32_fsv32, // batch, feature, 2D spatial. Blocks of 32 batch and 32 channels
bs_fs_yx_bsv32_fsv16, // batch, feature, 2D spatial. Blocks of 32 batch and 16 channels

View File

@ -506,6 +506,22 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 4, 4)
// Index helper for a 4D (b, f, y, x) tensor in the BS_FS_YX_BSV8_FSV4 layout.
// Expands to a call to get_bs_fs_zyx_bsv_fsv_index() with the z coordinate
// fixed to 0; the X/Y/Z sizes, the feature count and the feature/Z/Y/X
// paddings are taken from the definitions named by CAT(prefix, <suffix>).
// NOTE(review): the trailing "8, 4" are presumably the batch and feature
// block sizes expected by the helper - confirm against its full signature.
#define GET_DATA_BS_FS_YX_BSV8_FSV4_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 8, 4)
#define GET_DATA_BS_FS_YX_BSV4_FSV2_INDEX(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index( \
b, f, 0, y, x, \
@ -605,6 +621,23 @@ inline uint get_bs_fs_zyx_bsv_fsv_index(uint b, uint f, uint z, uint y, uint x,
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 4, 4)
// "Safe" index helper for a 4D (b, f, y, x) tensor in the BS_FS_YX_BSV8_FSV4
// layout. Expands to a call to get_bs_fs_zyx_bsv_fsv_index_safe() with the z
// coordinate fixed to 0. Compared with the non-safe variant it additionally
// forwards CAT(prefix, _BATCH_NUM) alongside the sizes, feature count and
// paddings derived from `prefix`.
// NOTE(review): the helper presumably uses the batch/feature counts to keep
// out-of-range coordinates inside the buffer, and the trailing "8, 4" are
// presumably the batch and feature block sizes - confirm against the helper.
#define GET_DATA_BS_FS_YX_BSV8_FSV4_INDEX_SAFE(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \
CAT(prefix, _SIZE_X), \
CAT(prefix, _SIZE_Y), \
CAT(prefix, _SIZE_Z), \
CAT(prefix, _FEATURE_NUM), \
CAT(prefix, _BATCH_NUM), \
CAT(prefix, _PAD_BEFORE_FEATURE_NUM), \
CAT(prefix, _PAD_AFTER_FEATURE_NUM), \
CAT(prefix, _PAD_BEFORE_SIZE_Z), \
CAT(prefix, _PAD_AFTER_SIZE_Z), \
CAT(prefix, _PAD_BEFORE_SIZE_Y), \
CAT(prefix, _PAD_AFTER_SIZE_Y), \
CAT(prefix, _PAD_BEFORE_SIZE_X), \
CAT(prefix, _PAD_AFTER_SIZE_X), 8, 4)
#define GET_DATA_BS_FS_YX_BSV4_FSV2_INDEX_SAFE(prefix, b, f, y, x) \
get_bs_fs_zyx_bsv_fsv_index_safe( \
b, f, 0, y, x, \

View File

@ -334,6 +334,7 @@ JitDefinitions DataTensorJitConstant::GetDefinitions() const {
layout == DataLayout::fs_b_yx_fsv32 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv16 ||
layout == DataLayout::bs_fs_yx_bsv4_fsv4 ||
layout == DataLayout::bs_fs_yx_bsv8_fsv4 ||
layout == DataLayout::bs_fs_yx_bsv4_fsv2 ||
layout == DataLayout::bs_fs_yx_bsv32_fsv16 ||
layout == DataLayout::bs_fs_yx_bsv32_fsv32) {
@ -346,6 +347,7 @@ JitDefinitions DataTensorJitConstant::GetDefinitions() const {
layout == DataLayout::bs_fs_yx_bsv32_fsv32 ||
layout == DataLayout::bs_fs_yx_bsv32_fsv16 ||
layout == DataLayout::bs_fs_yx_bsv4_fsv4 ||
layout == DataLayout::bs_fs_yx_bsv8_fsv4 ||
layout == DataLayout::bs_fs_yx_bsv4_fsv2 ||
layout == DataLayout::bs_fs_yx_bsv16_fsv16)
safe_index_func_val = "GET_DATA_" + layout_str + "_INDEX_SAFE(" + _name + ", b, f, y, x)";

View File

@ -105,6 +105,7 @@ std::string toString(DataLayout l) {
case kernel_selector::DataLayout::bs_fs_yx_bsv16_fsv16: return "BS_FS_YX_BSV16_FSV16";
case kernel_selector::DataLayout::bs_fs_zyx_bsv16_fsv16: return "BS_FS_ZYX_BSV16_FSV16";
case kernel_selector::DataLayout::bs_fs_yx_bsv4_fsv4: return "BS_FS_YX_BSV4_FSV4";
case kernel_selector::DataLayout::bs_fs_yx_bsv8_fsv4: return "BS_FS_YX_BSV8_FSV4";
case kernel_selector::DataLayout::bs_fs_yx_bsv4_fsv2: return "BS_FS_YX_BSV4_FSV2";
case kernel_selector::DataLayout::bs_fs_yx_bsv32_fsv32: return "BS_FS_YX_BSV32_FSV32";
case kernel_selector::DataLayout::bs_fs_yx_bsv32_fsv16: return "BS_FS_YX_BSV32_FSV16";

View File

@ -225,6 +225,11 @@ attach_convolution_impl::attach_convolution_impl() {
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),

View File

@ -214,6 +214,13 @@ attach_eltwise_impl::attach_eltwise_impl() {
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i64, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i64, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv2),

View File

@ -119,6 +119,11 @@ attach_concatenation_onednn::attach_concatenation_onednn() {
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
});
}

View File

@ -256,6 +256,11 @@ attach_convolution_onednn::attach_convolution_onednn() {
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),

View File

@ -199,6 +199,11 @@ attach_deconvolution_onednn::attach_deconvolution_onednn() {
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::i8, format::bs_fs_yx_bsv8_fsv4),
std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv2),
std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv2),

View File

@ -91,6 +91,7 @@ dnnl::memory::format_tag convert_data_format(cldnn::format fmt) {
case cldnn::format::bs_fs_yx_bsv16_fsv16: return dnnl::memory::format_tag::NChw16n16c;
case cldnn::format::bs_fs_yx_bsv32_fsv32: return dnnl::memory::format_tag::NChw32n32c;
case cldnn::format::bs_fs_yx_bsv4_fsv4: return dnnl::memory::format_tag::ABcd4a4b;
case cldnn::format::bs_fs_yx_bsv8_fsv4: return dnnl::memory::format_tag::ABcd8a4b;
case cldnn::format::bs_fs_yx_bsv4_fsv2: return dnnl::memory::format_tag::ABcd4a2b;
case cldnn::format::bs_fs_yx_bsv32_fsv16: return dnnl::memory::format_tag::NChw32n16c;
case cldnn::format::bs_fs_zyx_bsv16_fsv16: return dnnl::memory::format_tag::NCdhw16n16c;

View File

@ -97,6 +97,8 @@ inline std::string fmt_to_str(format fmt) {
return "bs_fs_yx_bsv4_fsv2";
case format::bs_fs_yx_bsv4_fsv4:
return "bs_fs_yx_bsv4_fsv4";
case format::bs_fs_yx_bsv8_fsv4:
return "bs_fs_yx_bsv8_fsv4";
case format::bs_fs_yx_bsv32_fsv32:
return "bs_fs_yx_bsv32_fsv32";
case format::b_fs_zyx_fsv16:

View File

@ -136,6 +136,8 @@ kernel_selector::data_layout to_data_layout(format f) {
return kernel_selector::data_layout::bs_fs_yx_bsv32_fsv16;
case format::bs_fs_yx_bsv4_fsv4:
return kernel_selector::data_layout::bs_fs_yx_bsv4_fsv4;
case format::bs_fs_yx_bsv8_fsv4:
return kernel_selector::data_layout::bs_fs_yx_bsv8_fsv4;
case format::bs_fs_yx_bsv4_fsv2:
return kernel_selector::data_layout::bs_fs_yx_bsv4_fsv2;
case format::bs_fs_yx_bsv32_fsv32:
@ -193,6 +195,8 @@ cldnn::format from_data_layout(kernel_selector::data_layout l) {
return cldnn::format::bs_fs_yx_bsv4_fsv2;
case kernel_selector::data_layout::bs_fs_yx_bsv4_fsv4:
return cldnn::format::bs_fs_yx_bsv4_fsv4;
case kernel_selector::data_layout::bs_fs_yx_bsv8_fsv4:
return cldnn::format::bs_fs_yx_bsv8_fsv4;
case kernel_selector::data_layout::bs_fs_yx_bsv32_fsv32:
return cldnn::format::bs_fs_yx_bsv32_fsv32;
case kernel_selector::data_layout::nv12:

View File

@ -284,10 +284,11 @@ bool layout_optimizer::can_fuse_reorder(program_node& prev, program_node& next,
return true;
if (next.is_type<convolution>() &&
(fmt_prev == format::b_fs_yx_fsv4 || fmt_prev == format::bs_fs_yx_bsv4_fsv4) &&
(fmt_prev == format::b_fs_yx_fsv4 || fmt_prev == format::bs_fs_yx_bsv4_fsv4 || fmt_prev == format::bs_fs_yx_bsv8_fsv4) &&
((fmt_next == format::b_fs_yx_fsv32 && (prev_output_layout.size.feature[0] == 3 || prev_output_layout.size.feature[0] == 4)) ||
(fmt_next == format::bs_fs_yx_bsv32_fsv32 && (prev_output_layout.size.feature[0] == 3 || prev_output_layout.size.feature[0] == 4)) ||
(fmt_next == format::bs_fs_yx_bsv4_fsv4 && (prev_output_layout.size.feature[0] == 3 || prev_output_layout.size.feature[0] == 4)) ||
(fmt_next == format::bs_fs_yx_bsv8_fsv4 && (prev_output_layout.size.feature[0] == 3 || prev_output_layout.size.feature[0] == 4)) ||
(fmt_next == format::b_fs_yx_fsv16 && next_output_layout.size.feature[0] >= 16 &&
(prev_output_layout.size.feature[0] == 3 || (prev_output_layout.size.feature[0] == 4 && (prev_dt == data_types::u8 || prev_dt == data_types::i8))))))
return true;
@ -1269,6 +1270,7 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
format::bs_fs_yx_bsv32_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bs_fs_yx_bsv4_fsv4,
format::bs_fs_yx_bsv8_fsv4,
format::bs_fs_yx_bsv4_fsv2,
format::bs_fs_zyx_bsv4_fsv4,
format::bs_fs_zyx_bsv4_fsv2,
@ -1463,7 +1465,7 @@ format layout_optimizer::get_preferred_format(program_node& node) {
if (data_type_traits::is_floating_point(conv.get_output_layout().data_type) || ws.spatial[0] != 7 || conv.get_primitive()->groups > 1)
expected = format::bfyx;
else
expected = format::bs_fs_yx_bsv4_fsv4;
expected = format::bs_fs_yx_bsv8_fsv4;
auto conv_output_layout = conv.get_output_layout();
auto weights_layout = conv.weights(0).get_output_layout();

View File

@ -139,30 +139,25 @@ std::pair<bool, bool> program_helpers::are_layouts_identical(layout const& l1, l
return {false, false};
if (l1.get_linear_size() != l2.get_linear_size())
return {false, false};
if ((l1.format == format::b_fs_yx_fsv4 && l2.format != format::b_fs_yx_fsv4) ||
(l2.format == format::b_fs_yx_fsv4 && l1.format != format::b_fs_yx_fsv4) ||
(l1.format == format::fs_b_yx_fsv32 && l2.format != format::fs_b_yx_fsv32) ||
(l2.format == format::fs_b_yx_fsv32 && l1.format != format::fs_b_yx_fsv32) ||
(l1.format == format::b_fs_yx_fsv16 && l2.format != format::b_fs_yx_fsv16) ||
(l2.format == format::b_fs_yx_fsv16 && l1.format != format::b_fs_yx_fsv16) ||
(l1.format == format::b_fs_yx_fsv32 && l2.format != format::b_fs_yx_fsv32) ||
(l2.format == format::b_fs_yx_fsv32 && l1.format != format::b_fs_yx_fsv32) ||
(l1.format == format::b_fs_zyx_fsv32 && l2.format != format::b_fs_zyx_fsv32) ||
(l2.format == format::b_fs_zyx_fsv32 && l1.format != format::b_fs_zyx_fsv32) ||
(l1.format == format::b_fs_zyx_fsv16 && l2.format != format::b_fs_zyx_fsv16) ||
(l2.format == format::b_fs_zyx_fsv16 && l1.format != format::b_fs_zyx_fsv16) ||
(l1.format == format::bs_fs_yx_bsv4_fsv4 && l2.format != format::bs_fs_yx_bsv4_fsv4) ||
(l2.format == format::bs_fs_yx_bsv4_fsv4 && l1.format != format::bs_fs_yx_bsv4_fsv4) ||
(l1.format == format::bs_fs_yx_bsv4_fsv2 && l2.format != format::bs_fs_yx_bsv4_fsv2) ||
(l2.format == format::bs_fs_yx_bsv4_fsv2 && l1.format != format::bs_fs_yx_bsv4_fsv2) ||
(l1.format == format::bs_fs_yx_bsv32_fsv16 && l2.format != format::bs_fs_yx_bsv32_fsv16) ||
(l2.format == format::bs_fs_yx_bsv32_fsv16 && l1.format != format::bs_fs_yx_bsv32_fsv16) ||
(l1.format == format::bs_fs_yx_bsv32_fsv32 && l2.format != format::bs_fs_yx_bsv32_fsv32) ||
(l2.format == format::bs_fs_yx_bsv32_fsv32 && l1.format != format::bs_fs_yx_bsv32_fsv32) ||
(l1.format == format::bs_fs_yx_bsv16_fsv16 && l2.format != format::bs_fs_yx_bsv16_fsv16) ||
(l2.format == format::bs_fs_yx_bsv16_fsv16 && l1.format != format::bs_fs_yx_bsv16_fsv16) ||
(l1.format == format::bs_fs_zyx_bsv16_fsv16 && l2.format != format::bs_fs_zyx_bsv16_fsv16) ||
(l2.format == format::bs_fs_zyx_bsv16_fsv16 && l1.format != format::bs_fs_zyx_bsv16_fsv16))
auto check_format = [&l1, &l2](cldnn::format format) {
return (l1.format == format && l2.format != format) ||
(l2.format == format && l1.format != format);
};
if (check_format(format::b_fs_yx_fsv4) ||
check_format(format::fs_b_yx_fsv32) ||
check_format(format::b_fs_yx_fsv16) ||
check_format(format::b_fs_yx_fsv32) ||
check_format(format::b_fs_zyx_fsv32) ||
check_format(format::b_fs_zyx_fsv16) ||
check_format(format::bs_fs_yx_bsv4_fsv4) ||
check_format(format::bs_fs_yx_bsv8_fsv4) ||
check_format(format::bs_fs_yx_bsv4_fsv2) ||
check_format(format::bs_fs_yx_bsv32_fsv16) ||
check_format(format::bs_fs_yx_bsv32_fsv32) ||
check_format(format::bs_fs_yx_bsv16_fsv16) ||
check_format(format::bs_fs_zyx_bsv16_fsv16))
return {false, false};
auto l1_pitch = l1.get_pitches();