[GPU] Use b_fs_yx_fsv16 format for OneDNN convolutins in case of FP32 output (#8808)

This commit is contained in:
Sergey Shlyapnikov 2021-11-27 15:08:06 +03:00 committed by GitHub
parent ab3a892d48
commit e76fc14ae1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -876,6 +876,18 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
}
if (use_onednn_impls) {
std::function<bool(const program_node&)> has_any_convolutions_below;
has_any_convolutions_below = [&](const program_node& node) -> bool {
if (node.get_users().empty())
return false;
for (auto& usr : node.get_users()) {
if (usr->is_type<convolution>())
return true;
return has_any_convolutions_below(*usr);
}
return false;
};
/* ***************************** OneDNN impls format selection part ****************************** */
bool valid_grouped = !is_dw && prim->groups > 1 && (ofm_per_group % compute_block == 0 && ifm_per_group % compute_block == 0);
if (i8_u8_input) {
@ -883,7 +895,12 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
if (input_layout.size.batch[0] % 16 == 0) {
expected_format = cldnn::format::bs_fs_yx_bsv32_fsv32;
} else {
expected_format = cldnn::format::b_fs_yx_fsv32;
if (data_type_traits::is_floating_point(output_layout.data_type) &&
!has_any_convolutions_below(node)) {
expected_format = cldnn::format::b_fs_yx_fsv16;
} else {
expected_format = cldnn::format::b_fs_yx_fsv32;
}
}
} else if ((_optimization_attributes.b_fs_yx_fsv16_network &&
convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) && is_2d) {