[GPU] Use b_fs_yx_fsv16 format for OneDNN convolutions in case of FP32 output (#8808)

Sergey Shlyapnikov 2021-11-27 15:08:06 +03:00 committed by GitHub
parent ab3a892d48
commit e76fc14ae1

@@ -876,15 +876,32 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
         }
         if (use_onednn_impls) {
+            std::function<bool(const program_node&)> has_any_convolutions_below;
+            has_any_convolutions_below = [&](const program_node& node) -> bool {
+                if (node.get_users().empty())
+                    return false;
+                for (auto& usr : node.get_users()) {
+                    if (usr->is_type<convolution>())
+                        return true;
+                    return has_any_convolutions_below(*usr);
+                }
+                return false;
+            };
             /* ***************************** OneDNN impls format selection part ****************************** */
             bool valid_grouped = !is_dw && prim->groups > 1 && (ofm_per_group % compute_block == 0 && ifm_per_group % compute_block == 0);
             if (i8_u8_input) {
                 if ((non_grouped || valid_grouped || valid_int8_dw) && onednn_valid_post_ops && is_2d) {
                     if (input_layout.size.batch[0] % 16 == 0) {
                         expected_format = cldnn::format::bs_fs_yx_bsv32_fsv32;
+                    } else {
+                        if (data_type_traits::is_floating_point(output_layout.data_type) &&
+                            !has_any_convolutions_below(node)) {
+                            expected_format = cldnn::format::b_fs_yx_fsv16;
                     } else {
                         expected_format = cldnn::format::b_fs_yx_fsv32;
                     }
+                    }
                 } else if ((_optimization_attributes.b_fs_yx_fsv16_network &&
                             convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) && is_2d) {
                     // TODO: optimize clDNN kernels for good support of b_fs_yx_fsv32 format
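
The hunk above is cut off, but the heuristic it introduces is small enough to show on its own. Below is a minimal standalone C++ sketch, not the clDNN sources: Node, Format and pick_format are hypothetical stand-ins for cldnn::program_node, cldnn::format and the selection code in layout_optimizer.cpp, and this version of has_any_convolutions_below walks every user at each level rather than returning from the first loop iteration as the committed lambda does. It illustrates the new rule: an int8/uint8 convolution with a floating-point output and no further convolutions downstream gets b_fs_yx_fsv16 instead of b_fs_yx_fsv32.

#include <iostream>
#include <string>
#include <vector>

enum class Format { b_fs_yx_fsv16, b_fs_yx_fsv32, bs_fs_yx_bsv32_fsv32 };
enum class DataType { i8, u8, f16, f32 };

struct Node {
    std::string id;
    DataType output_type;
    bool is_convolution;
    std::vector<const Node*> users;  // consumers of this node's output
};

// True if a convolution is reachable through the users chain.
static bool has_any_convolutions_below(const Node& node) {
    for (const Node* usr : node.users) {
        if (usr->is_convolution || has_any_convolutions_below(*usr))
            return true;
    }
    return false;
}

// Simplified version of the int8/uint8 branch of the format selection:
// batch divisible by 16 keeps bs_fs_yx_bsv32_fsv32; otherwise a floating-point
// output with no further convolutions downstream now prefers b_fs_yx_fsv16,
// and everything else stays on b_fs_yx_fsv32.
static Format pick_format(const Node& conv, DataType input_type, int batch) {
    const bool i8_u8_input = (input_type == DataType::i8 || input_type == DataType::u8);
    const bool fp_output = (conv.output_type == DataType::f16 || conv.output_type == DataType::f32);
    if (!i8_u8_input)
        return Format::b_fs_yx_fsv16;  // other branches of the real selector are out of scope here
    if (batch % 16 == 0)
        return Format::bs_fs_yx_bsv32_fsv32;
    if (fp_output && !has_any_convolutions_below(conv))
        return Format::b_fs_yx_fsv16;  // the case this commit adds
    return Format::b_fs_yx_fsv32;
}

int main() {
    // conv -> reorder: last convolution in the chain, FP32 output, batch 1
    Node reorder{"reorder", DataType::f32, false, {}};
    Node conv{"conv", DataType::f32, true, {&reorder}};
    std::cout << std::boolalpha
              << (pick_format(conv, DataType::u8, 1) == Format::b_fs_yx_fsv16) << "\n";  // true
    return 0;
}

Reading the code this way, the downstream-convolution check presumably keeps chains of int8 convolutions on the fsv32 layout and only switches the final, FP-output convolution to fsv16.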