[GPU] Add oneDNN FC preferred_format to bfyx (#15704)

Signed-off-by: hyunback <hyunback.kim@intel.com>
Author:    hyunback kim
Committed: 2023-02-24 15:19:54 +09:00 (committed via GitHub)
Commit:    be5f90199d
Parent:    f562e96305
5 changed files with 28 additions and 3 deletions


@@ -39,8 +39,9 @@ bool is_batch_after_spatial(const std::string order) {
 }
 format::type get_preferred_format(fully_connected_node const& node, const kernel_impl_params& impl_param) {
-    if (node.get_preferred_impl_type() == impl_types::onednn)
-        return format::bfyx;
+    if (node.get_preferred_impl_type() == impl_types::onednn && node.get_preferred_output_fmt() != format::any) {
+        return node.get_preferred_output_fmt();
+    }
     auto input_layout = impl_param.get_input_layout();
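
The effect of this change is that an oneDNN fully connected node no longer gets bfyx forced unconditionally; it reports whatever preferred output format was recorded on the node, and only falls back to the generic path when nothing was recorded. A minimal standalone sketch of that rule, using simplified stand-in types rather than the real cldnn classes:

```cpp
// Sketch only: impl_types, format, fc_node and get_preferred_format below are
// simplified stand-ins, not the cldnn API.
#include <iostream>

enum class impl_types { ocl, onednn };
enum class format { any, bfyx, b_fs_yx_fsv16 };

struct fc_node {
    impl_types impl = impl_types::ocl;
    format preferred_output = format::any;
};

format get_preferred_format(const fc_node& node, format input_format) {
    // New behaviour: trust the format chosen for the oneDNN primitive, if any.
    if (node.impl == impl_types::onednn && node.preferred_output != format::any)
        return node.preferred_output;
    // Fallback: keep following the input layout (stand-in for the rest of the
    // real selection logic).
    return input_format;
}

int main() {
    fc_node n{impl_types::onednn, format::bfyx};
    std::cout << (get_preferred_format(n, format::b_fs_yx_fsv16) == format::bfyx) << "\n";
}
```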


@@ -78,6 +78,10 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c
     auto output_format = input0_layout.format;
+    if (node.get_preferred_impl_type() == impl_types::onednn && node.get_preferred_output_fmt() != format::any) {
+        output_format = node.get_preferred_output_fmt();
+    }
     return layout(output_shape, output_type, output_format, prim->output_paddings[0]);
 }
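
The gemm output-layout calculation follows the same pattern: the output keeps the first input's format unless the node runs on oneDNN and a preferred output format was recorded for it. A compact sketch with hypothetical, simplified types:

```cpp
// Sketch only: gemm_node and calc_output_format are simplified stand-ins.
#include <iostream>

enum class impl_types { ocl, onednn };
enum class format { any, bfyx, bfzyx };

struct gemm_node {
    impl_types impl = impl_types::ocl;
    format preferred_output = format::any;
};

format calc_output_format(const gemm_node& node, format input0_format) {
    format output_format = input0_format;               // default: follow input 0
    if (node.impl == impl_types::onednn && node.preferred_output != format::any)
        output_format = node.preferred_output;          // honour the recorded format
    return output_format;
}

int main() {
    gemm_node g{impl_types::onednn, format::bfyx};
    std::cout << (calc_output_format(g, format::bfzyx) == format::bfyx) << "\n";
}
```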


@@ -5,6 +5,7 @@
 #include "pass_manager.h"
 #include "data_inst.h"
 #include "mutable_data_inst.h"
+#include "gemm_inst.h"
 #include "program_node.h"
 #include "intel_gpu/runtime/engine.hpp"
 #include "intel_gpu/runtime/itt.hpp"
@@ -44,6 +45,8 @@ void select_preferred_formats::run(program& p) {
                                                             dnnl::primitive_attr(),
                                                             dnnl::memory::format_tag::any);
                 _lo.select_preferred_formats_for_onednn(*n, *prim_desc);
+            } else if (n->is_type<fully_connected>() || n->is_type<gemm>()) {
+                _lo.select_preferred_formats_for_onednn(*n);
             }
         } catch(std::exception &exception) {
             GPU_DEBUG_INFO << "WARNING(select_preferred_formats): " << exception.what() << std::endl;
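
This dispatch is what triggers the new layout_optimizer path: convolution and deconvolution nodes still hand over a concrete oneDNN primitive descriptor, while fully connected and gemm nodes call select_preferred_formats_for_onednn with no descriptor and rely on the default argument introduced in the header below. A standalone sketch of that control flow, with mock types in place of the real cldnn and oneDNN classes:

```cpp
// Sketch only: primitive_desc, node, layout_optimizer and run_pass are mocks,
// not the real cldnn / oneDNN types.
#include <iostream>
#include <string>
#include <vector>

struct primitive_desc {};                          // stand-in for dnnl::primitive_desc

struct node {
    std::string type;                              // "convolution", "fully_connected", ...
    bool is_type(const std::string& t) const { return type == t; }
};

struct layout_optimizer {
    // Default argument mirrors the header change in this commit.
    void select_preferred_formats_for_onednn(node& n,
                                             primitive_desc pd = primitive_desc()) {
        (void)pd;
        std::cout << "select preferred formats for " << n.type << "\n";
    }
};

void run_pass(std::vector<node>& nodes, layout_optimizer& lo) {
    for (auto& n : nodes) {
        if (n.is_type("convolution") || n.is_type("deconvolution")) {
            primitive_desc pd;                     // real code builds this via oneDNN
            lo.select_preferred_formats_for_onednn(n, pd);
        } else if (n.is_type("fully_connected") || n.is_type("gemm")) {
            lo.select_preferred_formats_for_onednn(n);   // new path: no descriptor
        }
    }
}

int main() {
    layout_optimizer lo;
    std::vector<node> nodes{{"convolution"}, {"fully_connected"}, {"gemm"}};
    run_pass(nodes, lo);
}
```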


@@ -210,7 +210,7 @@ public:
     bool should_select_b_fs_yx_fsv16_layout(convolution_node const& node, layout const& output_or_weights_layout);
 #ifdef ENABLE_ONEDNN_FOR_GPU
-    void select_preferred_formats_for_onednn(program_node& node, dnnl::primitive_desc prim_desc);
+    void select_preferred_formats_for_onednn(program_node& node, dnnl::primitive_desc prim_desc = dnnl::primitive_desc());
 #endif
 };
 } // namespace cldnn


@@ -1757,6 +1757,10 @@ format layout_optimizer::get_preferred_format(program_node& node) {
         // Set default format for issue 92967/98750
         // TODO: will remove when arg_max_min_ref supports blocked format
         expected = format::get_default_format(node.get_input_layouts()[0].get_rank(), false, false);
+    } else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
+        if (use_onednn_impls) {
+            expected = node.get_preferred_output_fmt();
+        }
     }
 
     if (allow_new_shape_infer && node.get_preferred_input_fmt() != format::any) {
@@ -1862,6 +1866,19 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
             GPU_DEBUG_LOG << "select_preferred_formats:" << node.id() << ": " << fmt_to_str(src_fmt) << " --> " << fmt_to_str(dst_fmt)
                           << " For index : " << idx << std::endl;
         }
-    }
+    } else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
+        for (size_t idx = 0 ; idx < node.get_dependencies().size() ; idx++) {
+            if (node.get_dependency(idx).is_constant())
+                continue;
+
+            node.set_preferred_input_fmt(idx, cldnn::format::bfyx);
+            if (node.get_preferred_output_fmt() == format::any) {
+                for (size_t usr = 0; usr < std::max<size_t>(1, node.get_users().size()); usr++)
+                    node.set_preferred_output_fmt(usr, cldnn::format::bfyx);
+            }
+            GPU_DEBUG_LOG << "select_preferred_formats:" << node.id() << ": " << fmt_to_str(cldnn::format::bfyx) << " --> " << fmt_to_str(cldnn::format::bfyx)
+                          << " For index : " << idx << std::endl;
+        }
+    }
 
     return;
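
The new branch pins every non-constant input of a fully connected or gemm node to bfyx and, if no output format has been chosen yet, pins the output to bfyx for all users as well; constant dependencies (weights, bias) keep their format. A runnable sketch of that assignment loop, again with mock types instead of cldnn::program_node:

```cpp
// Sketch only: dependency, node and set_fc_gemm_preferred_formats are mocks
// that mirror the structure of the new branch, not the cldnn implementation.
#include <cstddef>
#include <iostream>
#include <vector>

enum class format { any, bfyx };

struct dependency {
    bool constant = false;
};

struct node {
    std::vector<dependency> deps;
    std::vector<format> preferred_input;   // one entry per dependency
    format preferred_output = format::any;
};

void set_fc_gemm_preferred_formats(node& n) {
    n.preferred_input.resize(n.deps.size(), format::any);
    for (size_t idx = 0; idx < n.deps.size(); ++idx) {
        if (n.deps[idx].constant)
            continue;                               // weights/bias keep their format
        n.preferred_input[idx] = format::bfyx;      // plain bfyx activations for oneDNN
        if (n.preferred_output == format::any)
            n.preferred_output = format::bfyx;      // only fill in when still unset
    }
}

int main() {
    node fc{{{false}, {true}, {true}}, {}, format::any};
    set_fc_gemm_preferred_formats(fc);
    std::cout << (fc.preferred_output == format::bfyx) << "\n";
}
```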