From c14e6ef48e4870a19b95519ff9ba65c6c86fcc4c Mon Sep 17 00:00:00 2001 From: hyunback kim Date: Wed, 22 Mar 2023 17:08:10 +0900 Subject: [PATCH] [GPU] Use 4dim directly for onednn in gemm (#16182) * [GPU] Use 4-dim directly for onednn in gemm We were collapsing n-dim tensors into 3d for onednn gemm, but this is not necessary for tensors of up to 4 dims, so 4-dim tensors are now passed to onednn directly. Signed-off-by: hyunback --- .../src/graph/impls/onednn/gemm_onednn.cpp | 10 ---------- .../intel_gpu/src/graph/impls/onednn/utils.cpp | 17 ++++++++++------- .../tests/fusions/gemm_fusion_test.cpp | 18 +++++++++++++++--- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp b/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp index 84bcdd83d2e..309a4e24285 100644 --- a/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp @@ -301,16 +301,6 @@ public: } static std::unique_ptr create(const gemm_node& arg, const kernel_impl_params& impl_params) { - bool full_tensor_or_per_tensor = true; - for (auto prim : arg.get_fused_primitives()) { - if (prim.input_layout.is_static() && prim.output_layout.is_static()) { - full_tensor_or_per_tensor &= - prim.input_layout.count() == prim.output_layout.count() || prim.input_layout.count() == 1; - } - } - if (!full_tensor_or_per_tensor) { - IE_THROW() << "Unimplemented: per channel binary post-operation is not supported for onednn gemm. 
Refer PR(#15353) message."; - } auto& engine = impl_params.prog->get_engine(); auto& config = impl_params.prog->get_config(); auto attr = arg.get_onednn_primitive_attributes(); diff --git a/src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp b/src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp index 6b217b196c9..09e977b5edc 100644 --- a/src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp @@ -68,22 +68,25 @@ dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims, bool is_grouped) dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed) { auto sizes = t.sizes(default_fmt_for_dims(dims, false)); dnnl::memory::dims res(sizes.begin(), sizes.end()); - if (dims > 3) { - for (size_t i = 0; i < dims - 3; i++) { + if (dims > 4) { + for (size_t i = 0; i < dims - 4; i++) { res[i + 1] *= res[i]; } - res.erase(res.begin(), res.begin() + dims - 3); + res.erase(res.begin(), res.begin() + dims - 4); } - if (res.size() == 3 && batched_dims_can_be_removed) { + if (res.size() == 4 && batched_dims_can_be_removed) { res.erase(res.begin()); } return res; } dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) { - if (dims.size() > 3) - throw std::runtime_error("[clDNN] Unsupported dims size for onednn gemm: should be <= 3"); - return dims.size() == 3 ? 
dnnl::memory::format_tag::abc : dnnl::memory::format_tag::ab; + switch (dims.size()) { + case 2: return dnnl::memory::format_tag::ab; + case 3: return dnnl::memory::format_tag::abc; + case 4: return dnnl::memory::format_tag::abcd; + default: throw std::invalid_argument("[clDNN] Unsupported conversion from "+ std::to_string(dims.size()) + " to onednn format_tag"); + } } diff --git a/src/plugins/intel_gpu/tests/fusions/gemm_fusion_test.cpp b/src/plugins/intel_gpu/tests/fusions/gemm_fusion_test.cpp index 34b35f26c05..847c9192dd8 100644 --- a/src/plugins/intel_gpu/tests/fusions/gemm_fusion_test.cpp +++ b/src/plugins/intel_gpu/tests/fusions/gemm_fusion_test.cpp @@ -113,6 +113,9 @@ public: #define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx +#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 4, 6, 5 }, { 2, 3, 6, 4, 5 } }, { 2, 3, 6, 6, 5 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx +#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 7, 5, 3, 2 }, { 2, 3, 5, 7, 3, 2 } }, { 2, 3, 5, 5, 3, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx + #define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 
64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx @@ -298,11 +301,14 @@ TEST_P(gemm_2in_add, eltwise_postop) { add_data_size.feature[0] = 1; add_data_layout.set_tensor(add_data_size); + auto in_layout0 = get_input_layout(p, 0); + auto in_layout1 = get_input_layout(p, 1); + create_topologies( - input_layout("input0", get_input_layout(p, 0)), - input_layout("input1", get_input_layout(p, 1)), + input_layout("input0", in_layout0), + input_layout("input1", in_layout1), data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())), - gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32), + gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()), eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type), reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32) ); @@ -318,6 +324,12 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vec gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum }, gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod }, gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub }, + gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sum }, + gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::prod }, + gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sub }, + gemm_test_params{ 
CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum }, + gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod }, + gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub }, })); class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};