[GPU] Skip reorder_node_to_split to avoid change of input data type for onednn kernel support (#16827)
* Skip reorder_node_to_split when the new input data type is not supported by the oneDNN kernel * Update layout_optimizer and add a unit test
This commit is contained in:
@@ -5,6 +5,10 @@
|
||||
#include "pass_manager.h"
|
||||
#include "program_helpers.h"
|
||||
#include "reshape_inst.h"
|
||||
#include "layout_optimizer.h"
|
||||
|
||||
#include "gemm_inst.h"
|
||||
#include "pooling_inst.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
@@ -75,11 +79,48 @@ void handle_reshape::run(program& p) {
|
||||
// vector for storing nodes that are reorder type, for which split primitives are needed (except for the
|
||||
// first one where original reshape will be used)
|
||||
std::vector<program_node*> reorder_node_to_split;
|
||||
std::vector<program_node*> onednn_users;
|
||||
|
||||
// find the users of reshape that are reorder type, if none present then skip the current node
|
||||
// find users who are onednn impl
|
||||
for (const auto& user : node->get_users()) {
|
||||
if (user->is_type<reorder>())
|
||||
reorder_node_to_split.push_back(user);
|
||||
if (user->get_preferred_impl_type() == cldnn::impl_types::onednn)
|
||||
onednn_users.push_back(user);
|
||||
}
|
||||
|
||||
// If onednn user doesn't support new input data type from future "reorder:_reshape_input_" reorder,
|
||||
// remove target reorder_node to keep original datatype
|
||||
if (!onednn_users.empty() && !reorder_node_to_split.empty()) {
|
||||
// Copy reorder_node_to_split to iteration
|
||||
std::vector<program_node*> reorder_users(reorder_node_to_split);
|
||||
for (const auto& reorder_node : reorder_users) {
|
||||
auto output_data_type = reorder_node->get_output_layout().data_type;
|
||||
bool onednn_support = true;
|
||||
for (const auto& user : onednn_users) {
|
||||
auto out_dt = user->get_output_layout().data_type;
|
||||
if (user->is_type<fully_connected>() || user->is_type<gemm>()) {
|
||||
bool is_fc = user->is_type<fully_connected>();
|
||||
auto wei_dt = is_fc ? user->as<fully_connected>().weights().get_output_layout().data_type :
|
||||
user->as<gemm>().get_dependency(1).get_output_layout().data_type;
|
||||
onednn_support = layout_optimizer::onednn_check_data_types_for_fc_gemm(output_data_type, wei_dt, out_dt);
|
||||
} else if (user->is_type<convolution>() || user->is_type<deconvolution>()) {
|
||||
bool is_conv = user->is_type<convolution>();
|
||||
auto wei_dt = is_conv ? user->as<convolution>().weights().get_output_layout().data_type :
|
||||
user->as<deconvolution>().weights().get_output_layout().data_type;
|
||||
onednn_support = layout_optimizer::onednn_check_data_types_for_convolution(output_data_type, wei_dt, out_dt);
|
||||
} else if (user->is_type<pooling>()) {
|
||||
onednn_support = layout_optimizer::onednn_check_data_types_for_pooling(output_data_type, out_dt);
|
||||
}
|
||||
|
||||
if (!onednn_support) {
|
||||
reorder_node_to_split.erase(std::remove(reorder_node_to_split.begin(), reorder_node_to_split.end(), reorder_node),
|
||||
reorder_node_to_split.end());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!reorder_node_to_split.empty()) {
|
||||
|
||||
@@ -189,6 +189,9 @@ public:
|
||||
impl_types get_forced_impl_type_by_config(program_node& node);
|
||||
static bool are_data_types_suitable_for_onednn(program_node& node);
|
||||
bool are_layouts_suitable_for_onednn(program_node& node);
|
||||
static bool onednn_check_data_types_for_pooling(data_types in_dt, data_types out_dt);
|
||||
static bool onednn_check_data_types_for_convolution(data_types in_dt, data_types wei_dt, data_types out_dt);
|
||||
static bool onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt);
|
||||
bool is_primitive_implemented_for_onednn(program_node& node);
|
||||
bool is_format_supported(program_node& node, format::type fmt);
|
||||
|
||||
|
||||
@@ -77,6 +77,47 @@ static bool is_reduce_blocked_axes(reduce_node const& node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns true when the oneDNN pooling primitive supports the given
// (input, output) data-type pair.
//
// in_dt  - data type of the pooling input.
// out_dt - data type of the pooling output.
bool layout_optimizer::onednn_check_data_types_for_pooling(data_types in_dt, data_types out_dt) {
    // Non-floating-point inputs must keep their type through pooling;
    // only floating-point inputs may be converted on output.
    if (!data_type_traits::is_floating_point(in_dt) && in_dt != out_dt)
        return false;
    // NOTE(review): the original code also had an early branch
    // "(in_dt is i8/u8) && out_dt != f32 -> true"; it is fully subsumed by the
    // i8/u8 check below (and after the guard above a quantized in_dt forces
    // out_dt == in_dt anyway), so it is dropped here without behavior change.
    if (in_dt == data_types::f16 || out_dt == data_types::f16)
        return true;
    if (out_dt == data_types::f32)
        return true;
    if (in_dt == data_types::i32 || out_dt == data_types::i32)
        return true;
    // Any quantized (i8/u8) type on either side is accepted.
    if ((in_dt == data_types::i8 || out_dt == data_types::i8) || (in_dt == data_types::u8 || out_dt == data_types::u8))
        return true;
    return false;
}
|
||||
|
||||
// Returns true when the oneDNN convolution/deconvolution primitive supports
// the given (input, weights, output) data-type combination.
//
// in_dt  - data type of the activations.
// wei_dt - data type of the weights.
// out_dt - data type of the output.
bool layout_optimizer::onednn_check_data_types_for_convolution(data_types in_dt, data_types wei_dt, data_types out_dt) {
    const bool out_is_quantized = (out_dt == data_types::i8 || out_dt == data_types::u8);

    // f16 activations with f16 weights: f16/f32 or quantized output.
    if (in_dt == data_types::f16 && wei_dt == data_types::f16)
        return out_dt == data_types::f16 || out_dt == data_types::f32 || out_is_quantized;

    // Quantized activations require i8 weights; a wide set of outputs is allowed.
    if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8)
        return out_dt == data_types::f32 || out_dt == data_types::i32 ||
               out_dt == data_types::f16 || out_is_quantized;

    // f32 activations and weights are accepted only with a quantized output.
    if (in_dt == data_types::f32 && wei_dt == data_types::f32)
        return out_is_quantized;

    return false;
}
|
||||
|
||||
// Returns true when the oneDNN fully-connected/gemm primitive supports the
// given (input, weights, output) data-type combination.
//
// in_dt  - data type of the activations.
// wei_dt - data type of the weights.
// out_dt - data type of the output.
bool layout_optimizer::onednn_check_data_types_for_fc_gemm(data_types in_dt, data_types wei_dt, data_types out_dt) {
    // f16 activations with f16 weights: f16, f32 or i8 output.
    if (in_dt == data_types::f16 && wei_dt == data_types::f16)
        return out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8;

    // f32 activations with f32 weights: any output type is accepted.
    if (in_dt == data_types::f32 && wei_dt == data_types::f32)
        return true;

    // Quantized activations require i8 weights; a wide set of outputs is allowed.
    if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8)
        return out_dt == data_types::i8 || out_dt == data_types::u8 || out_dt == data_types::i32 ||
               out_dt == data_types::f16 || out_dt == data_types::f32;

    return false;
}
|
||||
|
||||
std::pair<std::shared_ptr<reorder>, bool> reorder_factory::get_reorder(primitive_id src_id,
|
||||
const layout& in_layout,
|
||||
const layout& out_layout) {
|
||||
@@ -1189,58 +1230,17 @@ bool layout_optimizer::are_data_types_suitable_for_onednn(program_node& node) {
|
||||
return false;
|
||||
|
||||
if (node.is_type<pooling>()) {
|
||||
if (!data_type_traits::is_floating_point(in_dt) && in_dt != out_dt)
|
||||
return false;
|
||||
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && out_dt != data_types::f32)
|
||||
return true;
|
||||
if (in_dt == data_types::f16 || out_dt == data_types::f16)
|
||||
return true;
|
||||
if (out_dt == data_types::f32)
|
||||
return true;
|
||||
if (in_dt == data_types::i32 || out_dt == data_types::i32)
|
||||
return true;
|
||||
if ((in_dt == data_types::i8 || out_dt == data_types::i8) || (in_dt == data_types::u8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
return onednn_check_data_types_for_pooling(in_dt, out_dt);
|
||||
} else if (node.is_type<convolution>() || node.is_type<deconvolution>()) {
|
||||
bool is_conv = node.is_type<convolution>();
|
||||
auto wei_dt = is_conv ? node.as<convolution>().weights().get_output_layout().data_type :
|
||||
node.as<deconvolution>().weights().get_output_layout().data_type;
|
||||
|
||||
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
|
||||
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && wei_dt == data_types::i8 &&
|
||||
(out_dt == data_types::f32 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
if ((in_dt == data_types::f32 && wei_dt == data_types::f32) &&
|
||||
(out_dt == data_types::i8 || out_dt == data_types::u8))
|
||||
return true;
|
||||
return onednn_check_data_types_for_convolution(in_dt, wei_dt, out_dt);
|
||||
} else if (node.is_type<fully_connected>() || node.is_type<gemm>()) {
|
||||
bool is_fc = node.is_type<fully_connected>();
|
||||
auto wei_dt = is_fc ? node.as<fully_connected>().weights().get_output_layout().data_type :
|
||||
node.as<gemm>().get_dependency(1).get_output_layout().data_type;
|
||||
|
||||
if ((in_dt == data_types::f16 && wei_dt == data_types::f16) &&
|
||||
(out_dt == data_types::f16 || out_dt == data_types::f32 || out_dt == data_types::i8))
|
||||
return true;
|
||||
if (in_dt == data_types::f32 && wei_dt == data_types::f32)
|
||||
return true;
|
||||
if ((in_dt == data_types::i8 || in_dt == data_types::u8) && (wei_dt == data_types::i8) &&
|
||||
(out_dt == data_types::i8 || out_dt == data_types::u8 || out_dt == data_types::i32 || out_dt == data_types::f16 || out_dt == data_types::f32))
|
||||
return true;
|
||||
} else if (node.is_type<reorder>()) {
|
||||
auto input_fmt = node.get_dependency(0).get_output_layout().format;
|
||||
auto output_fmt = node.get_output_layout().format;
|
||||
|
||||
// For mixed precision case, oneDNN is slower than clDNN
|
||||
if (input_fmt == format::b_fs_yx_fsv16 && data_type_traits::is_i8_u8(in_dt))
|
||||
return false;
|
||||
if (output_fmt == format::b_fs_yx_fsv16 && data_type_traits::is_i8_u8(in_dt))
|
||||
return false;
|
||||
if (output_fmt == format::bfyx && out_dt == data_types::f32)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
return onednn_check_data_types_for_fc_gemm(in_dt, wei_dt, out_dt);
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@@ -50,3 +50,43 @@ TEST(handle_reshape, dont_remove_reshape_that_changes_rank) {
|
||||
|
||||
ASSERT_TRUE(prog->get_node("reshape").can_be_optimized());
|
||||
}
|
||||
|
||||
// Verifies that handle_reshape does NOT split a reorder user of a reshape when
// doing so would feed an oneDNN primitive an unsupported input data type.
// Here: splitting would convert the FC input to f32 while the weights stay
// f16, a combination oneDNN FC does not support, so the f16 input must be kept.
TEST(handle_reshape, skip_reorder_node_to_split_when_onndnn_not_support) {
    // Onednn FC does not support fp32 input, fp16 weight. In such case, we need to ignore reorder_split from handle_reshape pass
    auto& engine = get_test_engine();
    // The scenario only applies to devices where oneDNN is used (immad support).
    if (!engine.get_device_info().supports_immad)
        return;

    auto input = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 9, 1, 1024} });
    auto data_01 = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 9, 1, 1024} });
    auto weights = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 1024, 1, 1024} });
    auto bias = engine.allocate_memory({ data_types::f16, format::bfyx, { 1, 1, 1, 1024} });

    // Graph: eltwise -> reshape -> { reorder(f32), fully_connected } -> eltwise.
    // The reshape has both a reorder user and an oneDNN (FC) user, which is the
    // condition the pass inspects before splitting.
    topology topology;
    topology.add(input_layout("input", input->get_layout()));
    topology.add(data("data", data_01));
    topology.add(data("weights", weights));
    topology.add(data("bias", bias));
    topology.add(eltwise("e1", input_info("input"), input_info("data"), eltwise_mode::sum));
    topology.add(reshape("reshape", input_info("e1"), tensor(9, 1, 1, 1024), cldnn::reshape::reshape_mode::base));
    topology.add(reorder("convert_to_f32", input_info("reshape"), { data_types::f32, format::bfyx, { 9, 1, 1, 1024} }));
    topology.add(fully_connected("matmul", input_info("reshape"), "weights", "bias", cldnn::padding(), 3, 2));
    topology.add(reorder("convert_to_f32_matmul", input_info("matmul"), { data_types::f32, format::bfyx, { 9, 1, 1, 1024} }));
    topology.add(eltwise("e2", input_info("convert_to_f32"), input_info("convert_to_f32_matmul"), eltwise_mode::sum));


    ExecutionConfig config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::optimize_data(true));
    // Build without running the full optimization pipeline so the passes below
    // can be applied explicitly.
    auto prog = program::build_program(engine, topology, config, false, true);

    layout_optimizer lo(true);
    lo.set_optimization_attribute(layout_optimizer::optimization_attributes_type::use_onednn_impls, true);
    reorder_factory rf;

    program_wrapper::apply_opt_pass<reorder_inputs>(*prog, lo, rf);
    program_wrapper::apply_opt_pass<handle_reshape>(*prog);

    ASSERT_NE(prog, nullptr);

    // The FC input must still be f16: the split reorder (to f32) was skipped.
    ASSERT_TRUE(prog->get_node("matmul").get_dependency(0).get_output_layout().data_type == data_types::f16);
}
|
||||
|
||||
Reference in New Issue
Block a user