[GPU] set bfyx to input format for shallow conv (#13614)

* set bfyx to input format for shallow conv
This commit is contained in:
Sungeun Kim 2022-10-31 10:22:39 +09:00 committed by GitHub
parent 454bc61018
commit eeabb86b80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1746,7 +1746,7 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
// Conv or deconv gets a preferred format for its data input based on source memory description
// But an input format for fused post-ops should be same with an output format of conv/deconv
size_t prim_input(0);
size_t prim_input(-1);
if (node.is_type<convolution>())
prim_input = node.get_dependency_index(node.as<convolution>().input());
if (node.is_type<deconvolution>())
@ -1759,6 +1759,12 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
else // Dep for fused post ops
src_fmt = onednn::find_data_format(prim_desc.dst_desc());
// WA: shallow convolution needs to set input format by bfyx.
// onednn recommended byxf for input format. It will insert reorder before shallow conv.
if (node.is_type<convolution>() && node.get_input_layouts()[0].feature() == 3) {
src_fmt = format::get_default_format(node.get_input_layouts()[0].get_rank(), false, false);
}
node.set_preferred_input_fmt(idx, src_fmt);
auto dst_fmt = onednn::find_data_format(prim_desc.dst_desc());