From d06a22f4e41c9e36cfb5899981fa5d246d5f10ec Mon Sep 17 00:00:00 2001
From: hyunback kim
Date: Tue, 28 Mar 2023 14:49:49 +0900
Subject: [PATCH] [GPU] Support FC+eltwise fusion in fp16 for OneDNN (#16303)

* [GPU] Support FC+eltwise fusion in fp16

Signed-off-by: hyunback
---
 .../prepare_primitive_fusing.cpp              |  12 +-
 .../graph/graph_optimizer/reorder_inputs.cpp  |  23 ++++
 .../fusions/fully_connected_fusion_test.cpp   | 110 +++++++++++++++++-
 3 files changed, 138 insertions(+), 7 deletions(-)

diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
index 938b599e85c..457c16a6fe5 100644
--- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
+++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp
@@ -558,10 +558,14 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
         return false;
     };
 
-    auto fc_supports_fusings = [](fully_connected_node& node) -> bool {
-        auto in_dt = node.get_dependency(0).get_output_layout().data_type;
-
-        return data_type_traits::is_i8_u8(in_dt);
+    auto fc_supports_fusings = [&](fully_connected_node& node) -> bool {
+        if (_lo.get_optimization_attributes().use_onednn_impls &&
+            _lo.get_preferred_impl_type(node, format::any /*dummy*/) == impl_types::onednn) {
+            return true;
+        } else {
+            auto in_dt = node.get_dependency(0).get_output_layout().data_type;
+            return data_type_traits::is_i8_u8(in_dt);
+        }
     };
 
     auto gemm_supports_fusings = [](gemm_node& node) -> bool {
diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp
index 3f05f294294..66c87f99816 100644
--- a/src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp
+++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp
@@ -1008,6 +1008,29 @@ void reorder_inputs::run(program& p, layout_optimizer& lo, reorder_factory& rf)
                     const auto prim_id = "broadcast:" + data.id() + "_broadcasted" + std::to_string(idx++);
                     auto broadcast_prim = std::make_shared<broadcast>(prim_id, cldnn::input_info(data.id()), gemm_layout.get_shape(), ov::AxisSet{});
 
+                    auto& broadcast_node = p.get_or_create(broadcast_prim);
+                    p.add_intermediate(broadcast_node, *node, fused_prim.dep_start_idx, true);
+                    broadcast_node.recalc_output_layouts(false);
+                }
+            }
+        } else if (node->is_type<fully_connected>() && node->get_preferred_impl_type() == impl_types::onednn) {
+            for (const auto& fused_prim : node->get_fused_primitives()) {
+                if (fused_prim.is_type<eltwise>() &&
+                    one_of(fused_prim.typed_desc<eltwise>()->mode, {eltwise_mode::sum, eltwise_mode::sub, eltwise_mode::prod})) {
+                    auto fc_layout = node->get_output_layout();
+                    auto& data = node->get_dependency(fused_prim.dep_start_idx);
+                    auto data_layout = data.get_output_layout();
+
+                    if ((fc_layout.batch() == 1 || fc_layout.feature() == 1) ||
+                        (data_layout.batch() == 1 && data_layout.feature() == 1) ||
+                        (fc_layout.count() == data_layout.count())) {
+                        continue;
+                    }
+
+                    static size_t idx = 0;
+                    const auto prim_id = "broadcast:" + data.id() + "_broadcasted" + std::to_string(idx++);
+                    auto broadcast_prim = std::make_shared<broadcast>(prim_id, cldnn::input_info(data.id()), fc_layout.get_shape(), ov::AxisSet{});
+
                     auto& broadcast_node = p.get_or_create(broadcast_prim);
                     p.add_intermediate(broadcast_node, *node, fused_prim.dep_start_idx, true);
                     broadcast_node.recalc_output_layouts(false);
diff --git a/src/plugins/intel_gpu/tests/fusions/fully_connected_fusion_test.cpp b/src/plugins/intel_gpu/tests/fusions/fully_connected_fusion_test.cpp
index c0c810043f3..1c758c1988c 100644
--- a/src/plugins/intel_gpu/tests/fusions/fully_connected_fusion_test.cpp
+++ b/src/plugins/intel_gpu/tests/fusions/fully_connected_fusion_test.cpp
@@ -30,6 +30,7 @@ struct fully_connected_test_params {
     format default_format;
     size_t expected_fused_primitives;
     size_t expected_not_fused_primitives;
+    std::string ocl_kernel_name; // for onednn test
 };
 
 class FullyConnectedFusingTest : public ::BaseFusingTest<fully_connected_test_params> {
@@ -85,14 +86,23 @@ public:
         auto input_prim = p.data_type == data_types::u8 ? get_mem(get_input_layout(p), 0, 10) : get_mem(get_input_layout(p));
 
         auto impl_forcing = cfg_fused.get_property(ov::intel_gpu::force_implementations);
-
         auto forcing_format = p.input_format;
         for (auto& forcing : impl_forcing)
             if (forcing.first == "fc_prim")
                 forcing_format = forcing.second.output_format;
 
-        ov::intel_gpu::ImplementationDesc conv_impl = { forcing_format, "", impl_types::onednn };
-        cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "fc_prim", conv_impl } }));
+        ov::intel_gpu::ImplementationDesc fc_impl = { forcing_format, "", impl_types::onednn };
+        cfg_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "fc_prim", fc_impl } }));
+
+        if (!p.ocl_kernel_name.empty()) {
+            auto ocl_impl_forcing = cfg_not_fused.get_property(ov::intel_gpu::force_implementations);
+            auto ocl_forcing_format = p.input_format;
+            for (auto& forcing : ocl_impl_forcing)
+                if (forcing.first == "fc_prim")
+                    ocl_forcing_format = forcing.second.output_format;
+            ov::intel_gpu::ImplementationDesc fc_ocl_impl = { ocl_forcing_format, p.ocl_kernel_name /*fully_connected_gpu_bfyx_ref*/ };
+            cfg_not_fused.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "fc_prim", fc_ocl_impl } }));
+        }
         network network_not_fused(this->engine, this->topology_non_fused, cfg_not_fused);
         network network_fused(this->engine, this->topology_fused, cfg_fused);
         network_fused.set_input_data("input", input_prim);
@@ -154,6 +164,16 @@ public:
 #define CASE_FC_U8S8_3D_3 { 2, 3, 1 }, { 2, 3, 15 }, { 15, 1, 1 }, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 #define CASE_FC_U8S8_3D_4 { 1, 512, 1024 }, { 1, 384, 1024 }, { 1024, 1024, 1 }, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
 
+#define CASE_FC_FP16_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_3 { 2, 32 }, { 2, 16 }, { 16, 32 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_4 { 128, 76 }, { 128, 768 }, { 768, 76 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_5 { 1, 128, 76 }, { 1, 128, 768 }, { 1, 768, 76 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_6 { 2, 1, 76 }, { 2, 1, 768 }, { 768, 76, 1 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_7 { 2, 128, 76 }, { 2, 128, 768 }, { 768, 76, 1 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_3D_1 { 2, 32, 3 }, { 2, 32, 16 }, { 16, 3, 1 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+#define CASE_FC_FP16_3D_2 { 1, 1, 3 }, { 1, 1, 32 }, { 32, 3, 1 }, data_types::f16, format::bfyx, data_types::f16, format::oiyx, data_types::f32, format::bfyx
+
 /* ----------------------------------------------------------------------------------------------------- */
 /* ---------------------------------------- FC cases --------------------------------------------------- */
 /* ----------------------------------------------------------------------------------------------------- */
@@ -429,4 +449,88 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, fc_int8_inputs_fused_fp32_sum, ::testing::
 //        fully_connected_test_params{ CASE_FC_U8S8_3D_2, 2, 4 },
         fully_connected_test_params{ CASE_FC_U8S8_3D_4, 2, 4 },
 }));
+
+
+class fc_fp16_eltwise_add : public FullyConnectedFusingTestOneDNN {};
+TEST_P(fc_fp16_eltwise_add, basic) {
+    auto p = GetParam();
+    create_topologies(
+        input_layout("input", get_input_layout(p)),
+        data("weights", get_mem(get_weights_layout(p))),
+        data("bias", get_mem(get_bias_layout(p))),
+        data("eltwise_data", get_mem(get_per_channel_layout(p), 1, 9)),
+        fully_connected("fc_prim", input_info("input"), "weights", "bias", padding(), get_output_dim_size(p)),
+        eltwise("eltwise", { input_info("fc_prim"), input_info("eltwise_data") }, eltwise_mode::sum),
+        reorder("reorder_bfyx", input_info("eltwise"), p.default_format, data_types::f32)
+    );
+
+    tolerance = 1e-2f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_SUITE_P(fusings_gpu, fc_fp16_eltwise_add, ::testing::ValuesIn(std::vector<fully_connected_test_params>{
+    // fully_connected_test_params{ CASE_FC_FP16_1, 2, 3, "fully_connected_gpu_bs_f_bsv16_b1" }, // TODO check a failure in fully_connected_gpu_bs_f_bsv16_b1 + eltwise in iGPU
+    // fully_connected_test_params{ CASE_FC_FP16_3D_3, 2, 3, "fully_connected_gpu_bfyx_ref" }, // TODO check onednn failure
+    fully_connected_test_params{ CASE_FC_FP16_1, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_2, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_4, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_5, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_6, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_7, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3D_1, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3D_2, 2, 3, "fully_connected_gpu_bfyx_ref" },
+}));
+
+class fc_fp16_eltwise_sub : public FullyConnectedFusingTestOneDNN {};
+TEST_P(fc_fp16_eltwise_sub, basic) {
+    auto p = GetParam();
+    create_topologies(
+        input_layout("input", get_input_layout(p)),
+        data("weights", get_mem(get_weights_layout(p))),
+        data("bias", get_mem(get_bias_layout(p))),
+        data("eltwise_data", get_mem(get_per_channel_layout(p), 1, 9)),
+        fully_connected("fc_prim", input_info("input"), "weights", "bias", padding(), get_output_dim_size(p)),
+        eltwise("eltwise", { input_info("fc_prim"), input_info("eltwise_data") }, eltwise_mode::sub),
+        reorder("reorder_bfyx", input_info("eltwise"), p.default_format, data_types::f32)
+    );
+
+    tolerance = 1e-1f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_SUITE_P(fusings_gpu, fc_fp16_eltwise_sub, ::testing::ValuesIn(std::vector<fully_connected_test_params>{
+    fully_connected_test_params{ CASE_FC_FP16_1, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_2, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3D_1, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3D_2, 2, 3, "fully_connected_gpu_bfyx_ref" },
+}));
+
+class fc_fp16_eltwise_prod : public FullyConnectedFusingTestOneDNN {};
+TEST_P(fc_fp16_eltwise_prod, basic) {
+    auto p = GetParam();
+    create_topologies(
+        input_layout("input", get_input_layout(p)),
+        data("weights", get_mem(get_weights_layout(p))),
+        data("bias", get_mem(get_bias_layout(p))),
+        data("eltwise_data", get_mem(get_per_channel_layout(p), 1, 9)),
+        fully_connected("fc_prim", input_info("input"), "weights", "bias", padding(), get_output_dim_size(p)),
+        eltwise("eltwise", { input_info("fc_prim"), input_info("eltwise_data") }, eltwise_mode::prod),
+        reorder("reorder_bfyx", input_info("eltwise"), p.default_format, data_types::f32)
+    );
+
+    tolerance = 1e-1f;
+    execute(p);
+}
+
+INSTANTIATE_TEST_SUITE_P(fusings_gpu, fc_fp16_eltwise_prod, ::testing::ValuesIn(std::vector<fully_connected_test_params>{
+    fully_connected_test_params{ CASE_FC_FP16_1, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_2, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3D_1, 2, 3, "fully_connected_gpu_bfyx_ref" },
+    fully_connected_test_params{ CASE_FC_FP16_3D_2, 2, 3, "fully_connected_gpu_bfyx_ref" },
+}));
+
+
 #endif
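
A minimal usage sketch (not part of the patch): it builds the pattern this change allows to fuse, an fp16 fully_connected followed by a broadcastable eltwise sum, and forces the oneDNN implementation for the FC node the same way the new tests do. Only the primitives and force_implementations call that appear in the diff above are taken from the source; the engine/memory/layout setup, the helper get_test_engine(), and the concrete shapes are assumptions based on the usual cldnn test utilities, and the intel_gpu headers are omitted.

    // Sketch only: fp16 FC + eltwise(sum) with the oneDNN impl forced on "fc_prim".
    using namespace cldnn;

    engine& eng = get_test_engine();  // assumed test helper

    // FC input 2x32, weights 16x32 -> FC output 2x16; eltwise data 1x16 needs broadcasting.
    layout in_layout(ov::PartialShape{ 2, 32 }, data_types::f16, format::bfyx);
    auto weights      = eng.allocate_memory(layout(ov::PartialShape{ 16, 32 }, data_types::f16, format::oiyx));
    auto bias         = eng.allocate_memory(layout(ov::PartialShape{ 1, 16 }, data_types::f16, format::bfyx));
    auto eltwise_data = eng.allocate_memory(layout(ov::PartialShape{ 1, 16 }, data_types::f16, format::bfyx));

    topology topo(
        input_layout("input", in_layout),
        data("weights", weights),
        data("bias", bias),
        data("eltwise_data", eltwise_data),
        fully_connected("fc_prim", input_info("input"), "weights", "bias"),
        eltwise("eltwise", { input_info("fc_prim"), input_info("eltwise_data") }, eltwise_mode::sum),
        reorder("reorder_bfyx", input_info("eltwise"), format::bfyx, data_types::f32));

    // Force the oneDNN implementation for the FC node, as the new tests do; with this
    // patch the eltwise above becomes a fused post-op of fc_prim instead of a separate kernel.
    ExecutionConfig cfg;
    ov::intel_gpu::ImplementationDesc fc_impl = { format::bfyx, "", impl_types::onednn };
    cfg.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "fc_prim", fc_impl } }));

    network net(eng, topo, cfg);  // then set_input_data("input", ...) and execute() as usual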