removed case to choose onednn impl for deconv (#17108)
- in_dt(f16) wei_dt(f16) out_dt(f32)
This commit is contained in:
@@ -191,6 +191,7 @@ public:
|
||||
bool are_layouts_suitable_for_onednn(program_node& node);
|
||||
static bool onednn_check_data_types_for_pooling(data_types in_dt, data_types out_dt);
|
||||
static bool onednn_check_data_types_for_convolution(data_types in_dt, data_types wei_dt, data_types out_dt);
|
||||
static bool onednn_check_data_types_for_deconvolution(data_types in_dt, data_types wei_dt, data_types out_dt);
|
||||
static bool onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt);
|
||||
static bool onednn_check_preferred_impl_type_of_users(program_node& node);
|
||||
bool is_primitive_implemented_for_onednn(program_node& node);
|
||||
|
||||
@@ -106,6 +106,22 @@ bool layout_optimizer::onednn_check_data_types_for_convolution(data_types in_dt,
|
||||
return false;
|
||||
}
|
||||
|
||||
// almost same with onednn_check_data_types_for_convolution.
|
||||
// removed case
|
||||
// - in_dt(f16) wei_dt(f16) out_dt(f32)
|
||||
bool layout_optimizer::onednn_check_data_types_for_deconvolution(data_types in_dt, data_types wei_dt, data_types out_dt) {
|
||||
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
|
||||
(out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
|
||||
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
if ((in_dt == data_types::f32 && wei_dt == data_types::f32) &&
|
||||
(out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool layout_optimizer::onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt) {
|
||||
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
|
||||
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
|
||||
@@ -1231,11 +1247,12 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
|
||||
|
||||
if (node.is_type<pooling>()) {
|
||||
return onednn_check_data_types_for_pooling(in_dt, out_dt);
|
||||
} else if (node.is_type<convolution>() || node.is_type<deconvolution>()) {
|
||||
bool is_conv = node.is_type<convolution>();
|
||||
auto wei_dt = is_conv ? node.as<convolution>().weights().get_output_layout().data_type :
|
||||
node.as<deconvolution>().weights().get_output_layout().data_type;
|
||||
} else if (node.is_type<convolution>()) {
|
||||
auto wei_dt = node.as<convolution>().weights().get_output_layout().data_type;
|
||||
return onednn_check_data_types_for_convolution(in_dt, wei_dt, out_dt);
|
||||
} else if (node.is_type<deconvolution>()) {
|
||||
auto wei_dt = node.as<deconvolution>().weights().get_output_layout().data_type;
|
||||
return onednn_check_data_types_for_deconvolution(in_dt, wei_dt, out_dt);
|
||||
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
|
||||
bool is_fc = node.is_type<fully_connected>();
|
||||
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout().data_type :
|
||||
|
||||
Reference in New Issue
Block a user