[GPU] Use b_fs_yx_fsv16 format for OneDNN convolutions in case of FP32 output (#8808)
This commit is contained in:
parent
ab3a892d48
commit
e76fc14ae1
@ -876,15 +876,32 @@ layout layout_optimizer::get_expected_layout(layout const& current_layout,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (use_onednn_impls) {
|
if (use_onednn_impls) {
|
||||||
|
std::function<bool(const program_node&)> has_any_convolutions_below;
|
||||||
|
has_any_convolutions_below = [&](const program_node& node) -> bool {
|
||||||
|
if (node.get_users().empty())
|
||||||
|
return false;
|
||||||
|
for (auto& usr : node.get_users()) {
|
||||||
|
if (usr->is_type<convolution>())
|
||||||
|
return true;
|
||||||
|
return has_any_convolutions_below(*usr);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
/* ***************************** OneDNN impls format selection part ****************************** */
|
/* ***************************** OneDNN impls format selection part ****************************** */
|
||||||
bool valid_grouped = !is_dw && prim->groups > 1 && (ofm_per_group % compute_block == 0 && ifm_per_group % compute_block == 0);
|
bool valid_grouped = !is_dw && prim->groups > 1 && (ofm_per_group % compute_block == 0 && ifm_per_group % compute_block == 0);
|
||||||
if (i8_u8_input) {
|
if (i8_u8_input) {
|
||||||
if ((non_grouped || valid_grouped || valid_int8_dw) && onednn_valid_post_ops && is_2d) {
|
if ((non_grouped || valid_grouped || valid_int8_dw) && onednn_valid_post_ops && is_2d) {
|
||||||
if (input_layout.size.batch[0] % 16 == 0) {
|
if (input_layout.size.batch[0] % 16 == 0) {
|
||||||
expected_format = cldnn::format::bs_fs_yx_bsv32_fsv32;
|
expected_format = cldnn::format::bs_fs_yx_bsv32_fsv32;
|
||||||
|
} else {
|
||||||
|
if (data_type_traits::is_floating_point(output_layout.data_type) &&
|
||||||
|
!has_any_convolutions_below(node)) {
|
||||||
|
expected_format = cldnn::format::b_fs_yx_fsv16;
|
||||||
} else {
|
} else {
|
||||||
expected_format = cldnn::format::b_fs_yx_fsv32;
|
expected_format = cldnn::format::b_fs_yx_fsv32;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else if ((_optimization_attributes.b_fs_yx_fsv16_network &&
|
} else if ((_optimization_attributes.b_fs_yx_fsv16_network &&
|
||||||
convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) && is_2d) {
|
convolution_b_fs_yx_fsv16_opt(input_layout, output_layout, weights_layout, prim)) && is_2d) {
|
||||||
// TODO: optimize clDNN kernels for good support of b_fs_yx_fsv32 format
|
// TODO: optimize clDNN kernels for good support of b_fs_yx_fsv32 format
|
||||||
|
Loading…
Reference in New Issue
Block a user