[GPU] set bfyx to input format for shallow conv (#13614)
* set bfyx to input format for shallow conv
This commit is contained in:
parent
454bc61018
commit
eeabb86b80
@ -1746,7 +1746,7 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
|
||||
|
||||
// Conv or deconv gets a preferred format for its data input based on source memory description
|
||||
// But an input format for fused post-ops should be same with an output format of conv/deconv
|
||||
size_t prim_input(0);
|
||||
size_t prim_input(-1);
|
||||
if (node.is_type<convolution>())
|
||||
prim_input = node.get_dependency_index(node.as<convolution>().input());
|
||||
if (node.is_type<deconvolution>())
|
||||
@ -1759,6 +1759,12 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
|
||||
else // Dep for fused post ops
|
||||
src_fmt = onednn::find_data_format(prim_desc.dst_desc());
|
||||
|
||||
// WA: shallow convolution needs to set input format by bfyx.
|
||||
// onednn recommended byxf for input format. It will insert reorder before shallow conv.
|
||||
if (node.is_type<convolution>() && node.get_input_layouts()[0].feature() == 3) {
|
||||
src_fmt = format::get_default_format(node.get_input_layouts()[0].get_rank(), false, false);
|
||||
}
|
||||
|
||||
node.set_preferred_input_fmt(idx, src_fmt);
|
||||
|
||||
auto dst_fmt = onednn::find_data_format(prim_desc.dst_desc());
|
||||
|
Loading…
Reference in New Issue
Block a user