[GPU] Add oneDNN FC preferred_format to bfyx (#15704)
Signed-off-by: hyunback <hyunback.kim@intel.com>
This commit is contained in:
parent
f562e96305
commit
be5f90199d
@ -39,8 +39,9 @@ bool is_batch_after_spatial(const std::string order) {
|
||||
}
|
||||
|
||||
format::type get_preferred_format(fully_connected_node const& node, const kernel_impl_params& impl_param) {
|
||||
if (node.get_preferred_impl_type() == impl_types::onednn)
|
||||
return format::bfyx;
|
||||
if (node.get_preferred_impl_type() == impl_types::onednn && node.get_preferred_output_fmt() != format::any) {
|
||||
return node.get_preferred_output_fmt();
|
||||
}
|
||||
|
||||
auto input_layout = impl_param.get_input_layout();
|
||||
|
||||
|
@ -78,6 +78,10 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c
|
||||
|
||||
auto output_format = input0_layout.format;
|
||||
|
||||
if (node.get_preferred_impl_type() == impl_types::onednn && node.get_preferred_output_fmt() != format::any) {
|
||||
output_format = node.get_preferred_output_fmt();
|
||||
}
|
||||
|
||||
return layout(output_shape, output_type, output_format, prim->output_paddings[0]);
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "pass_manager.h"
|
||||
#include "data_inst.h"
|
||||
#include "mutable_data_inst.h"
|
||||
#include "gemm_inst.h"
|
||||
#include "program_node.h"
|
||||
#include "intel_gpu/runtime/engine.hpp"
|
||||
#include "intel_gpu/runtime/itt.hpp"
|
||||
@ -44,6 +45,8 @@ void select_preferred_formats::run(program& p) {
|
||||
dnnl::primitive_attr(),
|
||||
dnnl::memory::format_tag::any);
|
||||
_lo.select_preferred_formats_for_onednn(*n, *prim_desc);
|
||||
} else if (n->is_type<fully_connected>() || n->is_type<gemm>()) {
|
||||
_lo.select_preferred_formats_for_onednn(*n);
|
||||
}
|
||||
} catch(std::exception &exception) {
|
||||
GPU_DEBUG_INFO << "WARNING(select_preferred_formats): " << exception.what() << std::endl;
|
||||
|
@ -210,7 +210,7 @@ public:
|
||||
bool should_select_b_fs_yx_fsv16_layout(convolution_node const& node, layout const& output_or_weights_layout);
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
void select_preferred_formats_for_onednn(program_node& node, dnnl::primitive_desc prim_desc);
|
||||
void select_preferred_formats_for_onednn(program_node& node, dnnl::primitive_desc prim_desc = dnnl::primitive_desc());
|
||||
#endif
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -1757,6 +1757,10 @@ format layout_optimizer::get_preferred_format(program_node& node) {
|
||||
// Set default format for issue 92967/98750
|
||||
// TODO: will remove when arg_max_min_ref supports blocked format
|
||||
expected = format::get_default_format(node.get_input_layouts()[0].get_rank(), false, false);
|
||||
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
|
||||
if (use_onednn_impls) {
|
||||
expected = node.get_preferred_output_fmt();
|
||||
}
|
||||
}
|
||||
|
||||
if (allow_new_shape_infer && node.get_preferred_input_fmt() != format::any) {
|
||||
@ -1862,6 +1866,19 @@ void layout_optimizer::select_preferred_formats_for_onednn(program_node& node, d
|
||||
GPU_DEBUG_LOG << "select_preferred_formats:" << node.id() << ": " << fmt_to_str(src_fmt) << " --> " << fmt_to_str(dst_fmt)
|
||||
<< " For index : " << idx << std::endl;
|
||||
}
|
||||
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
|
||||
for (size_t idx = 0 ; idx < node.get_dependencies().size() ; idx++) {
|
||||
if (node.get_dependency(idx).is_constant())
|
||||
continue;
|
||||
node.set_preferred_input_fmt(idx, cldnn::format::bfyx);
|
||||
|
||||
if (node.get_preferred_output_fmt() == format::any) {
|
||||
for (size_t usr = 0; usr < std::max<size_t>(1, node.get_users().size()); usr++)
|
||||
node.set_preferred_output_fmt(usr, cldnn::format::bfyx);
|
||||
}
|
||||
GPU_DEBUG_LOG << "select_preferred_formats:" << node.id() << ": " << fmt_to_str(cldnn::format::bfyx) << " --> " << fmt_to_str(cldnn::format::bfyx)
|
||||
<< " For index : " << idx << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
|
Loading…
Reference in New Issue
Block a user