[GPU] Use 4-dim directly for onednn in gemm (#16182)
* [GPU] Use 4-dim directly for onednn in gemm

We were collapsing n-dim inputs down to 3-D for the oneDNN gemm path, but that is unnecessary: oneDNN handles up to 4-D directly, so only dimensions beyond the fourth need to be folded. For example, a 6-D shape {2, 3, 2, 3, 5, 7} now collapses to the 4-D {12, 3, 5, 7} instead of the 3-D {36, 5, 7}.

Signed-off-by: hyunback <hyunback.kim@intel.com>
commit c14e6ef48e (parent 0070e8d939)
@@ -301,16 +301,6 @@ public:
     }
 
     static std::unique_ptr<primitive_impl> create(const gemm_node& arg, const kernel_impl_params& impl_params) {
-        bool full_tensor_or_per_tensor = true;
-        for (auto prim : arg.get_fused_primitives()) {
-            if (prim.input_layout.is_static() && prim.output_layout.is_static()) {
-                full_tensor_or_per_tensor &=
-                    prim.input_layout.count() == prim.output_layout.count() || prim.input_layout.count() == 1;
-            }
-        }
-        if (!full_tensor_or_per_tensor) {
-            IE_THROW() << "Unimplemented: per channel binary post-operation is not supported for onednn gemm. Refer PR(#15353) message.";
-        }
         auto& engine = impl_params.prog->get_engine();
         auto& config = impl_params.prog->get_config();
         auto attr = arg.get_onednn_primitive_attributes();
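The deleted guard rejected any fused binary post-op whose operand was neither full-tensor nor a single scalar, presumably because the old collapse-to-3-D step broke the correspondence between per-channel post-op shapes and the gemm dimensions. A minimal standalone restatement of the removed predicate (names hypothetical, not part of the codebase):

#include <cstddef>

// Sketch of the removed check: a fused binary post-op was accepted only when
// its operand covered the whole gemm output (full-tensor) or was one scalar
// (per-tensor). A per-channel operand made this return false and hit the
// IE_THROW in the deleted block above.
bool full_tensor_or_per_tensor(size_t post_op_elems, size_t output_elems) {
    return post_op_elems == output_elems || post_op_elems == 1;
}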
@@ -68,22 +68,25 @@ dnnl::memory::dims convert_tensor(cldnn::tensor t, size_t dims, bool is_grouped)
 dnnl::memory::dims convert_gemm_tensor(cldnn::tensor t, size_t dims, bool batched_dims_can_be_removed) {
     auto sizes = t.sizes(default_fmt_for_dims(dims, false));
     dnnl::memory::dims res(sizes.begin(), sizes.end());
-    if (dims > 3) {
-        for (size_t i = 0; i < dims - 3; i++) {
+    if (dims > 4) {
+        for (size_t i = 0; i < dims - 4; i++) {
             res[i + 1] *= res[i];
         }
-        res.erase(res.begin(), res.begin() + dims - 3);
+        res.erase(res.begin(), res.begin() + dims - 4);
     }
-    if (res.size() == 3 && batched_dims_can_be_removed) {
+    if (res.size() == 4 && batched_dims_can_be_removed) {
         res.erase(res.begin());
     }
     return res;
 }
 
 dnnl::memory::format_tag convert_gemm_data_format(dnnl::memory::dims dims) {
-    if (dims.size() > 3)
-        throw std::runtime_error("[clDNN] Unsupported dims size for onednn gemm: should be <= 3");
-    return dims.size() == 3 ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::ab;
+    switch (dims.size()) {
+        case 2: return dnnl::memory::format_tag::ab;
+        case 3: return dnnl::memory::format_tag::abc;
+        case 4: return dnnl::memory::format_tag::abcd;
+        default: throw std::invalid_argument("[clDNN] Unsupported conversion from " + std::to_string(dims.size()) + " to onednn format_tag");
+    }
 }
 
 
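A self-contained sketch of the new collapse rule, with the cldnn/oneDNN types swapped for standard ones (collapse_to_4d is a hypothetical name, not the real function): every leading dimension beyond the last three is folded into a single batch dimension, so an n-D shape becomes at most 4-D where it previously became at most 3-D, and convert_gemm_data_format can now answer with format_tag::abcd.

#include <cstdint>
#include <iostream>
#include <vector>

// Fold all leading (batch-like) dimensions into one so that at most four
// dimensions remain -- the same arithmetic as the updated convert_gemm_tensor.
std::vector<int64_t> collapse_to_4d(std::vector<int64_t> res) {
    const size_t dims = res.size();
    if (dims > 4) {
        for (size_t i = 0; i < dims - 4; i++)
            res[i + 1] *= res[i];                         // accumulate into the next dim
        res.erase(res.begin(), res.begin() + (dims - 4)); // drop the folded leading dims
    }
    return res;
}

int main() {
    // A 6-D shape {2, 3, 2, 3, 5, 7} collapses to {12, 3, 5, 7}: the three
    // leading dims merge (2 * 3 * 2 = 12), and the 4-D result now maps to
    // format_tag::abcd instead of being rejected by the old 3-D-only path.
    for (int64_t d : collapse_to_4d({2, 3, 2, 3, 5, 7}))
        std::cout << d << ' ';
    std::cout << '\n';
    return 0;
}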
@@ -113,6 +113,9 @@ public:
 #define CASE_GEMM_2IN_FP16_3 { { 1, 1, 64, 64 }, { 1, 1, 64, 64 } }, { 1, 1, 64, 64 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
 #define CASE_GEMM_2IN_FP16_4 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
 #define CASE_GEMM_2IN_FP16_5 { { 2, 3, 2, 2 }, { 2, 3, 2, 2 } }, { 2, 3, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfyx, data_types::f16, format::bfyx
+#define CASE_GEMM_2IN_FP16_5D_1 { { 2, 3, 4, 6, 5 }, { 2, 3, 6, 4, 5 } }, { 2, 3, 6, 6, 5 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfzyx, data_types::f16, format::bfzyx
+#define CASE_GEMM_2IN_FP16_6D_1 { { 2, 3, 7, 5, 3, 2 }, { 2, 3, 5, 7, 3, 2 } }, { 2, 3, 5, 5, 3, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::f16, data_types::f16, data_types::f16, format::bfwzyx, data_types::f16, format::bfwzyx
+
 #define CASE_GEMM_2IN_U8U8_1 { { 1, 1, 2, 2 }, { 1, 1, 2, 2 } }, { 1, 1, 2, 2 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 #define CASE_GEMM_2IN_U8U8_2 { { 1, 2, 64, 128 }, { 1, 2, 256, 64 } }, { 1, 2, 256, 128 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
 #define CASE_GEMM_2IN_U8U8_3 { { 1, 1, 16, 32 }, { 1, 1, 32, 16 } }, { 1, 1, 32, 32 }, tensor{ 1 }, tensor{ 0 }, data_types::u8, data_types::u8, data_types::u8, format::bfyx, data_types::f32, format::bfyx
@@ -298,11 +301,14 @@ TEST_P(gemm_2in_add, eltwise_postop) {
     add_data_size.feature[0] = 1;
     add_data_layout.set_tensor(add_data_size);
 
+    auto in_layout0 = get_input_layout(p, 0);
+    auto in_layout1 = get_input_layout(p, 1);
+
     create_topologies(
-        input_layout("input0", get_input_layout(p, 0)),
-        input_layout("input1", get_input_layout(p, 1)),
+        input_layout("input0", in_layout0),
+        input_layout("input1", in_layout1),
         data("add_data", get_mem(add_data_layout, 1.0f/p.kernel.count())),
-        gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32),
+        gemm("gemm_prim", { input_info("input0"), input_info("input1") }, data_types::f32, false, false, 1.f, 0.f, in_layout0.get_rank(), in_layout1.get_rank()),
         eltwise("add_prim", { input_info("gemm_prim"), input_info("add_data") }, p.eltwise_m, p.default_type),
         reorder("reorder_bfyx", input_info("add_prim"), p.default_format, data_types::f32)
     );
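The test previously relied on the gemm constructor's default arguments; the new call spells them out so the last two can differ. Assuming the cldnn gemm primitive's parameter order is (output type, transpose flags, alpha, beta, input ranks), as the call suggests, it reads roughly as follows (a commented restatement of the line above, not new behavior):

gemm("gemm_prim",
     { input_info("input0"), input_info("input1") },
     data_types::f32,          // output data type
     false, false,             // transpose_input0, transpose_input1 (defaults kept)
     1.f, 0.f,                 // alpha, beta (defaults kept)
     in_layout0.get_rank(),    // rank of input0, forwarded so a 5-D/6-D shape
     in_layout1.get_rank());   // is no longer silently treated as 4-D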
@@ -318,6 +324,12 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, gemm_2in_add, ::testing::ValuesIn(std::vec
     gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
     gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
     gemm_test_params{ CASE_GEMM_2IN_FP16_5, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
+    gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sum },
+    gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::prod },
+    gemm_test_params{ CASE_GEMM_2IN_FP16_5D_1, 3, 4, "", dim_vec_kind::batch, eltwise_mode::sub },
+    gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sum },
+    gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::prod },
+    gemm_test_params{ CASE_GEMM_2IN_FP16_6D_1, 3, 4, "", dim_vec_kind::feature, eltwise_mode::sub },
 }));
 
 class gemm_2in_act_scale_quantize_i8 : public GemmFusingTest {};