[GPU] Use 4dim directly for onednn in gemm (#16182)

* [GPU] Use 4-dim directly for onednn in gemm
   We were collapsing n-dim inputs into 3D for the oneDNN gemm, but that is not necessary: oneDNN supports up to 4D directly.

Signed-off-by: hyunback <hyunback.kim@intel.com>
This commit is contained in:
hyunback kim 2023-03-22 17:08:10 +09:00 committed by GitHub
parent 0070e8d939
commit c14e6ef48e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 20 deletions

View File

@ -301,16 +301,6 @@ public:
} }
static std::unique_ptr<primitive_impl> create(const gemm_node& arg, const kernel_impl_params& impl_params) { static std::unique_ptr<primitive_impl> create(const gemm_node& arg, const kernel_impl_params& impl_params) {
bool full_tensor_or_per_tensor = true;
for (auto prim : arg.get_fused_primitives()) {
if (prim.input_layout.is_static() && prim.output_layout.is_static()) {
full_tensor_or_per_tensor &=
prim.input_layout.count() == prim.output_layout.count() || prim.input_layout.count() == 1;
}
}
if (!full_tensor_or_per_tensor) {
IE_THROW() << "Unimplemented: per channel binary post-operation is not supported for onednn gemm. Refer PR(#15353) message.";
}
auto& engine = impl_params.prog->get_engine(); auto& engine = impl_params.prog->get_engine();
auto& config = impl_params.prog->get_config(); auto& config = impl_params.prog->get_config();
auto attr = arg.get_onednn_primitive_attributes(); auto attr = arg.get_onednn_primitive_attributes();

View File

@ -68,22 +68,25 @@ dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims, bool is_grouped)
dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed) { dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed) {
auto sizes = t.sizes(default_fmt_for_dims(dims, false)); auto sizes = t.sizes(default_fmt_for_dims(dims, false));
dnnl::memory::dims res(sizes.begin(), sizes.end()); dnnl::memory::dims res(sizes.begin(), sizes.end());
if (dims > 3) { if (dims > 4) {
for (size_t i = 0; i < dims - 3; i++) { for (size_t i = 0; i < dims - 4; i++) {
res[i + 1] *= res[i]; res[i + 1] *= res[i];
} }
res.erase(res.begin(), res.begin() + dims - 3); res.erase(res.begin(), res.begin() + dims - 4);
} }
if (res.size() == 3 && batched_dims_can_be_removed) { if (res.size() == 4 && batched_dims_can_be_removed) {
res.erase(res.begin()); res.erase(res.begin());
} }
return res; return res;
} }
dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) { dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) {
if (dims.size() > 3) switch (dims.size()) {
throw std::runtime_error("[clDNN] Unsupported dims size for onednn gemm: should be <= 3"); case 2: return dnnl::memory::format_tag::ab;
return dims.size() == 3 ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::ab; case 3: return dnnl::memory::format_tag::abc;
case 4: return dnnl::memory::format_tag::abcd;
default: throw std::invalid_argument("[clDNN] Unsupported conversion from "+ std::to_string(dims.size()) + " to onednn format_tag");
}
} }

View File

@ -113,6 +113,9 @@ public:
#define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx #define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 4, 6, 5 }, { 2, 3, 6, 4, 5 } }, { 2, 3, 6, 6, 5 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 7, 5, 3, 2 }, { 2, 3, 5, 7, 3, 2 } }, { 2, 3, 5, 5, 3, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx
#define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
#define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx #define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
@ -298,11 +301,14 @@ TEST_P(gemm_2in_add, eltwise_postop) {
add_data_size.feature[0] = 1; add_data_size.feature[0] = 1;
add_data_layout.set_tensor(add_data_size); add_data_layout.set_tensor(add_data_size);
auto in_layout0 = get_input_layout(p, 0);
auto in_layout1 = get_input_layout(p, 1);
create_topologies( create_topologies(
input_layout("input0", get_input_layout(p, 0)), input_layout("input0", in_layout0),
input_layout("input1", get_input_layout(p, 1)), input_layout("input1", in_layout1),
data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())), data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32), gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type), eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32) reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
); );
@ -318,6 +324,12 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vec
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum }, gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod }, gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub }, gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sub },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
})); }));
class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {}; class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};