[GPU] Skip reorder_node_to_split to avoid change of input data type for onednn kernel support (#16827)

* skip reorder_node_to_split when new input data type of onednn kernel is not supported
* update layout_optimizer and add unit test
This commit is contained in:
Wilson Seok
2023-04-19 15:00:55 +09:00
committed by GitHub
parent 1281074e15
commit 2401b0aa3c
4 changed files with 128 additions and 44 deletions

View File

@@ -5,6 +5,10 @@
#include "pass_manager.h"
#include "program_helpers.h"
#include "reshape_inst.h"
#include "layout_optimizer.h"
#include "gemm_inst.h"
#include "pooling_inst.h"
#include <iterator>
#include <vector>
@@ -75,11 +79,48 @@ void handle_reshape::run(program& p) {
// vector for storing nodes that are reorder type, for which split primitives are needed (except for the
// first one where the original reshape will be used)
std::vector<program_node*> reorder_node_to_split;
std::vector<program_node*> onednn_users;
// find the users of reshape that are reorder type, if none present then skip the current node
// find users who are onednn impl
for (const auto& user : node->get_users()) {
if (user->is_type<reorder>())
reorder_node_to_split.push_back(user);
if (user->get_preferred_impl_type() == cldnn::impl_types::onednn)
onednn_users.push_back(user);
}
// If an onednn user doesn't support the new input data type coming from the future "reorder:_reshape_input_" reorder,
// remove the target reorder_node so the original data type is kept
if (!onednn_users.empty() && !reorder_node_to_split.empty()) {
// Copy reorder_node_to_split to iteration
std::vector<program_node*> reorder_users(reorder_node_to_split);
for (const auto& reorder_node : reorder_users) {
auto output_data_type = reorder_node->get_output_layout().data_type;
bool onednn_support = true;
for (const auto& user : onednn_users) {
auto out_dt = user->get_output_layout().data_type;
if (user->is_type<fully_connected>() || user->is_type<gemm>()) {
bool is_fc = user->is_type<fully_connected>();
auto wei_dt = is_fc ? user->as<fully_connected>().weights().get_output_layout().data_type :
user->as<gemm>().get_dependency(1).get_output_layout().data_type;
onednn_support = layout_optimizer::onednn_check_data_types_for_fc_gemm(output_data_type, wei_dt, out_dt);
} else if (user->is_type<convolution>() || user->is_type<deconvolution>()) {
bool is_conv = user->is_type<convolution>();
auto wei_dt = is_conv ? user->as<convolution>().weights().get_output_layout().data_type :
user->as<deconvolution>().weights().get_output_layout().data_type;
onednn_support = layout_optimizer::onednn_check_data_types_for_convolution(output_data_type, wei_dt, out_dt);
} else if (user->is_type<pooling>()) {
onednn_support = layout_optimizer::onednn_check_data_types_for_pooling(output_data_type, out_dt);
}
if (!onednn_support) {
reorder_node_to_split.erase(std::remove(reorder_node_to_split.begin(), reorder_node_to_split.end(), reorder_node),
reorder_node_to_split.end());
break;
}
}
}
}
if (!reorder_node_to_split.empty()) {

View File

@@ -189,6 +189,9 @@ public:
impl_types get_forced_impl_type_by_config(program_node& node);
static bool are_data_types_suitable_for_onednn(program_node& node);
bool are_layouts_suitable_for_onednn(program_node& node);
static bool onednn_check_data_types_for_pooling(data_types in_dt, data_types out_dt);
static bool onednn_check_data_types_for_convolution(data_types in_dt, data_types wei_dt, data_types out_dt);
static bool onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt);
bool is_primitive_implemented_for_onednn(program_node& node);
bool is_format_supported(program_node& node, format::type fmt);

View File

@@ -77,6 +77,47 @@ static bool is_reduce_blocked_axes(reduce_node const& node) {
return false;
}
// Returns true when oneDNN pooling supports the given input/output data type pair.
// Non-floating-point inputs must keep the same type on the output; beyond that,
// any combination touching f16/f32/i32/i8/u8 in the positions checked below is accepted.
bool layout_optimizer::onednn_check_data_types_for_pooling(data_types in_dt, data_types out_dt) {
    // Integral (non-fp) inputs are only allowed when the output type matches exactly.
    if (!data_type_traits::is_floating_point(in_dt) && in_dt != out_dt)
        return false;

    const bool in_is_quantized = (in_dt == data_types::i8) || (in_dt == data_types::u8);
    if (in_is_quantized && out_dt != data_types::f32)
        return true;

    // Any f16 endpoint, or an f32 output, is supported.
    if (in_dt == data_types::f16 || out_dt == data_types::f16 || out_dt == data_types::f32)
        return true;

    if (in_dt == data_types::i32 || out_dt == data_types::i32)
        return true;

    const bool out_is_quantized = (out_dt == data_types::i8) || (out_dt == data_types::u8);
    return in_is_quantized || out_is_quantized;
}
// Returns true when oneDNN convolution/deconvolution supports the given
// input/weights/output data type triple. The three accepted paths (f16, quantized
// i8/u8, f32) are keyed on the input+weights pair; each restricts the output type.
bool layout_optimizer::onednn_check_data_types_for_convolution(data_types in_dt, data_types wei_dt, data_types out_dt) {
    // f16 activations with f16 weights.
    if (in_dt == data_types::f16 && wei_dt == data_types::f16) {
        return out_dt == data_types::f16 || out_dt == data_types::f32 ||
               out_dt == data_types::i8  || out_dt == data_types::u8;
    }

    // Quantized activations (i8/u8) with i8 weights.
    if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8) {
        return out_dt == data_types::f32 || out_dt == data_types::i32 ||
               out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8;
    }

    // f32 activations with f32 weights — only quantized outputs are taken by oneDNN here.
    if (in_dt == data_types::f32 && wei_dt == data_types::f32) {
        return out_dt == data_types::i8 || out_dt == data_types::u8;
    }

    return false;
}
// Returns true when oneDNN fully_connected/gemm supports the given
// input/weights/output data type triple. Paths are keyed on the input+weights
// pair (mutually exclusive), so each one can decide the result on its own.
bool layout_optimizer::onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt) {
    // f16 activations with f16 weights.
    if (in_dt == data_types::f16 && wei_dt == data_types::f16) {
        return out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8;
    }

    // f32 activations with f32 weights — any output type is accepted.
    if (in_dt == data_types::f32 && wei_dt == data_types::f32) {
        return true;
    }

    // Quantized activations (i8/u8) with i8 weights.
    if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8) {
        return out_dt == data_types::i8  || out_dt == data_types::u8 || out_dt == data_types::i32 ||
               out_dt == data_types::f16 || out_dt == data_types::f32;
    }

    return false;
}
std::pair<std::shared_ptr<reorder>, bool> reorder_factory::get_reorder(primitive_id src_id,
const layout& in_layout,
const layout& out_layout) {
@@ -1189,58 +1230,17 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
return false;
if (node.is_type<pooling>()) {
if (!data_type_traits::is_floating_point(in_dt) && in_dt != out_dt)
return false;
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && out_dt != data_types::f32)
return true;
if (in_dt == data_types::f16 || out_dt == data_types::f16)
return true;
if (out_dt == data_types::f32)
return true;
if (in_dt == data_types::i32 || out_dt == data_types::i32)
return true;
if ((in_dt == data_types::i8 || out_dt == data_types::i8) || (in_dt == data_types::u8 || out_dt == data_types::u8))
return true;
return onednn_check_data_types_for_pooling(in_dt, out_dt);
} else if (node.is_type<convolution>() || node.is_type<deconvolution>()) {
bool is_conv = node.is_type<convolution>();
auto wei_dt = is_conv ? node.as<convolution>().weights().get_output_layout().data_type :
node.as<deconvolution>().weights().get_output_layout().data_type;
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
if ((in_dt == data_types::f32 && wei_dt == data_types::f32) &&
(out_dt == data_types::i8 || out_dt == data_types::u8))
return true;
return onednn_check_data_types_for_convolution(in_dt, wei_dt, out_dt);
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
bool is_fc = node.is_type<fully_connected>();
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout().data_type :
node.as<gemm>().get_dependency(1).get_output_layout().data_type;
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
return true;
if (in_dt == data_types::f32 && wei_dt == data_types::f32)
return true;
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && (wei_dt == data_types::i8) &&
(out_dt == data_types::i8 || out_dt == data_types::u8 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::f32))
return true;
} else if (node.is_type<reorder>()) {
auto input_fmt = node.get_dependency(0).get_output_layout().format;
auto output_fmt = node.get_output_layout().format;
// For mixed precision case, oneDNN is slower than clDNN
if (input_fmt == format::b_fs_yx_fsv16 && data_type_traits::is_i8_u8(in_dt))
return false;
if (output_fmt == format::b_fs_yx_fsv16 && data_type_traits::is_i8_u8(in_dt))
return false;
if (output_fmt == format::bfyx && out_dt == data_types::f32)
return false;
return true;
return onednn_check_data_types_for_fc_gemm(in_dt, wei_dt, out_dt);
}
return false;

View File

@@ -50,3 +50,43 @@ TEST(handle_reshape, dont_remove_reshape_that_changes_rank) {
ASSERT_TRUE(prog->get_node("reshape").can_be_optimized());
}
TEST(handle_reshape, skip_reorder_node_to_split_when_onndnn_not_support) {
    // oneDNN FC cannot take an f32 input together with f16 weights. handle_reshape must
    // therefore NOT split the "convert_to_f32" reorder onto the FC path — the FC input
    // has to stay f16. Only meaningful on devices with onednn support (immad).
    auto& engine = get_test_engine();
    if (!engine.get_device_info().supports_immad)
        return;

    auto in_mem   = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 9, 1, 1024} });
    auto add_mem  = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 9, 1, 1024} });
    auto wei_mem  = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 1024, 1, 1024} });
    auto bias_mem = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 1, 1, 1024} });

    // eltwise -> reshape, then two consumers of the reshape: an f32 reorder and an FC.
    topology topo;
    topo.add(input_layout("input", in_mem->get_layout()));
    topo.add(data("data", add_mem));
    topo.add(data("weights", wei_mem));
    topo.add(data("bias", bias_mem));
    topo.add(eltwise("e1", input_info("input"), input_info("data"), eltwise_mode::sum));
    topo.add(reshape("reshape", input_info("e1"), tensor(9, 1, 1, 1024), cldnn::reshape::reshape_mode::base));
    topo.add(reorder("convert_to_f32", input_info("reshape"), { data_types::f32, format::bfyx, { 9, 1, 1, 1024} }));
    topo.add(fully_connected("matmul", input_info("reshape"), "weights", "bias", cldnn::padding(), 3, 2));
    topo.add(reorder("convert_to_f32_matmul", input_info("matmul"), { data_types::f32, format::bfyx, { 9, 1, 1, 1024} }));
    topo.add(eltwise("e2", input_info("convert_to_f32"), input_info("convert_to_f32_matmul"), eltwise_mode::sum));

    ExecutionConfig config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::optimize_data(true));
    auto prog = program::build_program(engine, topo, config, false, true);

    layout_optimizer opt(true);
    opt.set_optimization_attribute(layout_optimizer::optimization_attributes_type::use_onednn_impls, true);
    reorder_factory factory;
    program_wrapper::apply_opt_pass<reorder_inputs>(*prog, opt, factory);
    program_wrapper::apply_opt_pass<handle_reshape>(*prog);
    ASSERT_NE(prog, nullptr);

    // The FC input must still be f16; a split reorder would have forced it to f32.
    ASSERT_TRUE(prog->get_node("matmul").get_dependency(0).get_output_layout().data_type == data_types::f16);
}