removed case to choose onednn impl for deconv (#17108)

- in_dt(f16) wei_dt(f16) out_dt(f32)
This commit is contained in:
Sungeun Kim
2023-04-26 13:20:11 +09:00
committed by GitHub
parent dabd5ee412
commit 3c485feea8
2 changed files with 22 additions and 4 deletions

View File

@@ -191,6 +191,7 @@ public:
bool are_layouts_suitable_for_onednn(program_node& node);
static bool onednn_check_data_types_for_pooling(data_types in_dt, data_types out_dt);
static bool onednn_check_data_types_for_convolution(data_types in_dt, data_types wei_dt, data_types out_dt);
static bool onednn_check_data_types_for_deconvolution(data_types in_dt, data_types wei_dt, data_types out_dt);
static bool onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt);
static bool onednn_check_preferred_impl_type_of_users(program_node& node);
bool is_primitive_implemented_for_onednn(program_node& node);

View File

@@ -106,6 +106,22 @@ bool layout_optimizer::onednn_check_data_types_for_convolution(data_types in_dt,
return false;
}
// almost same with onednn_check_data_types_for_convolution.
// removed case
// - in_dt(f16) wei_dt(f16) out_dt(f32)
bool layout_optimizer::onednn_check_data_types_for_deconvolution(data_types in_dt, data_types wei_dt, data_types out_dt) {
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
(out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
if ((in_dt == data_types::f32 && wei_dt == data_types::f32) &&
(out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
return false;
}
bool layout_optimizer::onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt) {
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
@@ -1231,11 +1247,12 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
if (node.is_type<pooling>()) {
return onednn_check_data_types_for_pooling(in_dt, out_dt);
} else if (node.is_type<convolution>() || node.is_type<deconvolution>()) {
bool is_conv = node.is_type<convolution>();
auto wei_dt = is_conv ? node.as<convolution>().weights().get_output_layout().data_type :
node.as<deconvolution>().weights().get_output_layout().data_type;
} else if (node.is_type<convolution>()) {
auto wei_dt = node.as<convolution>().weights().get_output_layout().data_type;
return onednn_check_data_types_for_convolution(in_dt, wei_dt, out_dt);
} else if (node.is_type<deconvolution>()) {
auto wei_dt = node.as<deconvolution>().weights().get_output_layout().data_type;
return onednn_check_data_types_for_deconvolution(in_dt, wei_dt, out_dt);
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
bool is_fc = node.is_type<fully_connected>();
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout().data_type :