[GPU] Forcing to use clDNN FC on small batch size (#11715)
Forces use of the clDNN FC implementation on small batch sizes due to a perf drop with the oneDNN FC primitive.

Signed-off-by: byungilm <byungil.min@intel.com>
@@ -818,6 +818,44 @@ static bool is_node_for_onednn(deconvolution_node const& node) {
     return onednn_valid_dt && onednn_valid_params && spatial_dims_num <= 3;
 }
 
+
+static bool is_node_for_onednn(program_node& node, fully_connected_node const& fc_node) {
+    bool is_suitable_for_onednn = true;
+    auto out_layout = node.get_output_layout();
+    for (auto& fo : node.get_fused_primitives()) {
+        if (fo.node->is_type<eltwise>()) {
+            // FC checks
+            auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
+            auto in_dt = in_layout.data_type;
+            auto out_dt = out_layout.data_type;
+            // if it is not an eltwise sum and the input is a full tensor
+            if ((out_layout.count() == in_layout.count()) && in_dt != out_dt
+                && (data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt))
+                && onednn_add_fusing_helpers::is_full_tensor(in_layout)) {
+                return false;
+            }
+
+            // WA: onednn sum/binary_add post-ops are not supported due to a perf drop.
+            auto add_type = onednn_add_fusing_helpers::get_add_fusing_type(node, fo);
+            if (add_type == add_fusing_type::sum || add_type == add_fusing_type::binary_per_tensor || add_type == add_fusing_type::binary_per_oc) {
+                return false;
+            }
+        }
+    }
+
+    auto fc_prim = fc_node.get_primitive();
+    size_t rank = cldnn::format::dimension(out_layout.format);
+    auto size = out_layout.size;
+    // OneDnn doesn't support spatial dimensions for the output
+    for (int i = 0; i < rank - 2 - (fc_prim->input_size == 3 ? 1 : 0); i++) {
+        if (size.spatial[i] != 1) {
+            return false;
+        }
+    }
+
+    return is_suitable_for_onednn;
+}
+
 bool layout_optimizer::needs_all_usr_onednn_small_ic_to_blocked(const program_node& node) {
     bool all_users_match = true;
     for (auto usr : node.get_users()) {
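Not part of the commit: the arithmetic in the spatial-dimension guard above is easy to misread, so here is a minimal standalone sketch of it. The toy_layout struct and the sample extents are hypothetical stand-ins for the cldnn layout types.

// Sketch of the spatial-dimension guard in is_node_for_onednn (assumptions noted above).
#include <cstddef>
#include <iostream>
#include <vector>

struct toy_layout {                 // hypothetical stand-in for cldnn::layout
    size_t rank;                    // dimension count of the output format
    size_t input_size;              // FC input rank from the primitive (2 or 3)
    std::vector<int> spatial;       // spatial extents of the output
};

// Mirrors the loop above: with a 3D input one spatial dimension belongs to
// the matmul shape, so one fewer spatial dim is required to be 1.
static bool spatial_dims_are_trivial(const toy_layout& l) {
    size_t checked = l.rank - 2 - (l.input_size == 3 ? 1 : 0);
    for (size_t i = 0; i < checked; i++)
        if (l.spatial[i] != 1)
            return false;
    return true;
}

int main() {
    std::cout << spatial_dims_are_trivial({4, 2, {1, 1}}) << "\n";  // 1: rank-4 output, 2D input, both spatial dims are 1
    std::cout << spatial_dims_are_trivial({4, 2, {7, 1}}) << "\n";  // 0: a non-trivial spatial dim rejects the onednn path
    std::cout << spatial_dims_are_trivial({4, 3, {1, 7}}) << "\n";  // 1: with a 3D input only one spatial dim is checked
}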
@@ -1538,33 +1576,20 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
             impl_candidate = impl_types::ocl;
         }
 
-        for (auto& fo : node.get_fused_primitives()) {
-            if (fo.node->is_type<eltwise>()) {
-                // FC checks
-                if (node.is_type<fully_connected>()) {
-                    auto in_layout = node.get_dependency(fo.dep_start_idx).get_output_layout();
-                    auto out_layout = node.get_output_layout();
-                    auto in_dt = in_layout.data_type;
-                    auto out_dt = out_layout.data_type;
-                    // if it is not an eltwise sum and the input is a full tensor
-                    if ((out_layout.count() == in_layout.count()) && in_dt != out_dt
-                        && (data_type_traits::is_floating_point(in_dt) || data_type_traits::is_floating_point(out_dt))
-                        && onednn_add_fusing_helpers::is_full_tensor(in_layout)) {
-                        impl_candidate = impl_types::ocl;
-                        break;
-                    }
+        if (node.is_type<fully_connected>()) {
+            if (!is_node_for_onednn(node, node.as<fully_connected>()))
+                impl_candidate = impl_types::ocl;
 
-                    // WA: onednn sum/binary_add post-ops are not supported due to a perf drop.
-                    auto add_type = onednn_add_fusing_helpers::get_add_fusing_type(node, fo);
-                    if (add_type == add_fusing_type::sum || add_type == add_fusing_type::binary_per_tensor || add_type == add_fusing_type::binary_per_oc) {
-                        impl_candidate = impl_types::ocl;
-                        break;
-                    }
-                }
-                // Gemm checks
-                // TODO: investigate why currently onednn gemm has some "sum" post-op restrictions
-                // which don't correlate with the fc checks in the code above
-                // Temporary WA: disable onednn gemm with a sum post-op inside
-            } else {
-                auto& e_node = fo.node->as<eltwise>();
-                if (e_node.get_primitive()->mode == eltwise_mode::sum) {
-                    impl_candidate = impl_types::ocl;
+            // WA: use the clDNN FC due to the perf drop at small batch sizes until the onednn FC improves
+            if (node.get_output_layout().batch() < 32)
+                impl_candidate = impl_types::ocl;
+        } else {
+            for (auto& fo : node.get_fused_primitives()) {
+                if (fo.node->is_type<eltwise>()) {
+                    // Gemm checks
+                    // TODO: investigate why currently onednn gemm has some "sum" post-op restrictions
+                    // which don't correlate with the fc checks in the code above
+                    // Temporary WA: disable onednn gemm with a sum post-op inside
+                    auto& e_node = fo.node->as<eltwise>();
+                    if (e_node.get_primitive()->mode == eltwise_mode::sum) {
+                        impl_candidate = impl_types::ocl;
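Not part of the commit: the batch threshold is the heart of this PR, and the refactored FC dispatch reduces to something like the sketch below. impl_kind and fc_is_onednn_suitable are hypothetical stand-ins for impl_types and the result of the new is_node_for_onednn() check.

// Condensed sketch of the new FC implementation choice (assumptions noted above).
#include <cstdint>
#include <iostream>

enum class impl_kind { onednn, ocl };

impl_kind preferred_fc_impl(int64_t batch_size, bool fc_is_onednn_suitable) {
    // Unsuitable fusings or output shapes already force the OCL (clDNN) kernel.
    if (!fc_is_onednn_suitable)
        return impl_kind::ocl;
    // WA from this commit: below batch 32 the clDNN FC outperforms the
    // oneDNN one, so prefer OCL until the oneDNN FC improves.
    if (batch_size < 32)
        return impl_kind::ocl;
    return impl_kind::onednn;
}

int main() {
    std::cout << (preferred_fc_impl(1, true) == impl_kind::ocl) << "\n";   // 1: small batch -> clDNN FC
    std::cout << (preferred_fc_impl(64, true) == impl_kind::ocl) << "\n";  // 0: large batch -> oneDNN FC
}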
@@ -1572,21 +1597,7 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
                 }
             }
         }
     }
 
-    if (node.is_type<fully_connected>()) {
-        auto fc_prim = node.as<fully_connected>().get_primitive();
-        auto out_layout = node.get_output_layout();
-        size_t rank = cldnn::format::dimension(out_layout.format);
-        auto size = out_layout.size;
-        // OneDnn doesn't support spatial dimensions for the output
-        for (int i = 0; i < rank - 2 - (fc_prim->input_size == 3 ? 1 : 0); i++) {
-            if (size.spatial[i] != 1) {
-                impl_candidate = impl_types::ocl;
-                break;
-            }
-        }
-    } else {
-        impl_candidate = impl_types::ocl;
-        auto gemm_prim = node.as<gemm>().get_primitive();
-        auto in0_l = node.get_dependency(0).get_output_layout();
+    auto gemm_prim = node.as<gemm>().get_primitive();
+    auto in0_l = node.get_dependency(0).get_output_layout();
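Not part of the commit: both the removed inline code and the new is_node_for_onednn() helper screen fused add post-ops the same way. A condensed sketch follows, with add_kind as a hypothetical stand-in for cldnn's add_fusing_type.

// Sketch of the sum/binary_add post-op workaround (assumptions noted above).
#include <iostream>

enum class add_kind { not_add, sum, binary_per_tensor, binary_per_oc, other };

// WA mirrored from the diff: any sum/binary_add style post-op disqualifies
// the node from the oneDNN path because of the observed perf drop.
bool add_post_op_allowed_for_onednn(add_kind k) {
    return k != add_kind::sum
        && k != add_kind::binary_per_tensor
        && k != add_kind::binary_per_oc;
}

int main() {
    std::cout << add_post_op_allowed_for_onednn(add_kind::sum) << "\n";    // 0: falls back to clDNN
    std::cout << add_post_op_allowed_for_onednn(add_kind::other) << "\n";  // 1: oneDNN still eligible
}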