[GPU] set b_fs_yx_fsv32 for first conv and dt is U8. (#14454)
* set b_fs_yx_fsv32 for first conv and dt is U8. * check input format is bfyx.
This commit is contained in:
parent
0602b852eb
commit
b799e3eb91
@ -1832,6 +1832,15 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
|
|||||||
node.set_preferred_input_fmt(idx, src_fmt);
|
node.set_preferred_input_fmt(idx, src_fmt);
|
||||||
|
|
||||||
auto dst_fmt = onednn::find_data_format(prim_desc.dst_desc());
|
auto dst_fmt = onednn::find_data_format(prim_desc.dst_desc());
|
||||||
|
// Errata: Best impl for shallow input conv with zero-point ops is ocl:xe_lp.
|
||||||
|
if (node.is_type<convolution>() && src_fmt == format::bfyx) {
|
||||||
|
auto& conv = node.as<convolution>();
|
||||||
|
if (conv.get_input_layouts()[0].feature() <= 8 && conv.activations_zero_points_term() &&
|
||||||
|
conv.get_input_layouts()[0].data_type == data_types::u8 && conv.get_output_layout().data_type == data_types::u8) {
|
||||||
|
dst_fmt = format::b_fs_yx_fsv32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (node.get_preferred_output_fmt() == format::any) {
|
if (node.get_preferred_output_fmt() == format::any) {
|
||||||
for (size_t usr = 0 ; usr < node.get_users().size() ; usr++)
|
for (size_t usr = 0 ; usr < node.get_users().size() ; usr++)
|
||||||
node.set_preferred_output_fmt(usr, dst_fmt);
|
node.set_preferred_output_fmt(usr, dst_fmt);
|
||||||
|
Loading…
Reference in New Issue
Block a user